From ba61109b6c8409f0b38da47b2825cf4b7bacd26b Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Fri, 26 Aug 2022 14:23:30 +0800 Subject: [PATCH] [FAW] remove code related to chunk (#1501) --- .../layers/cache_embedding/cache_mgr.py | 53 +++++++++---------- .../cache_embedding/freq_aware_embedding.py | 6 +-- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index 274c6bb92..19fe5d35d 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module): self.num_hits_history = [] self.num_miss_history = [] self.num_write_back_history = [] - self.input_id_percent_in_load_chunk = [] self._reset_comm_stats() self._evict_strategy = evict_strategy @@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module): # self.cuda_cached_weight = self.weight raise NotImplementedError() - def cpu_weight_data(self, chunk_id: int) -> torch.Tensor: + def cpu_weight_data(self, row_idx: int) -> torch.Tensor: """ - access a chunk of CPU weight. + access a row of CPU weight. Args: - chunk_id (int): chunk id + row_idx (int): the index of the row to access Returns: - torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D. + torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor has shape (1, embedding_dim). 
""" return self.weight.data.view(-1).narrow(0, - int(chunk_id) * self.embedding_dim, + int(row_idx) * self.embedding_dim, self.embedding_dim).view(1, self.embedding_dim) @property - def cuda_available_chunk_num(self): + def cuda_available_row_num(self): return self._cuda_available_row_num @torch.no_grad() @@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module): preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) if preload_row_num > 0: with Timer() as timer: - # extract chunks from cpu weight + # extract rows from cpu weight preload_row_ids = torch.arange(preload_row_num) preload_slot_ids = preload_row_ids.cuda() @@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module): src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: - preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks) + preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows) # update auxiliary info slot_offsets = preload_slot_ids @@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module): print(f'Cache warmup finished cost {timer.elapsed} sec.') def flush(self): - """flush all CUDA chunks to CPU. + """flush all CUDA rows to CPU. The function is usually called after training finished. 
""" slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) - chunk_ids = self.cached_idx_map[slots] - chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() - self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks) + row_ids = self.cached_idx_map[slots] + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() + self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows) self.cached_idx_map.index_fill_(0, slots, -1) - self.inverted_cached_idx.index_fill_(0, chunk_ids, -1) + self.inverted_cached_idx.index_fill_(0, row_ids, -1) self._cuda_available_row_num += slots.numel() assert self._cuda_available_row_num == self.cuda_row_num @@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module): cpu_row_idxs = torch.unique(cpu_row_idxs_original) assert len(cpu_row_idxs) <= self.cuda_row_num, \ - f"the input indices pull {len(cpu_row_idxs)} chunks, " \ - f"which is larger than the presented {self.cuda_row_num}, " \ - f"please increase cuda_row_num shrink batch size" + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + f"Please increase cuda_row_num or decrease the training batch size." self.evict_backlist = cpu_row_idxs - with record_function("(zhg) get cpu chunk indices"): + with record_function("(zhg) get cpu row idxs"): comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs)) self.num_write_back_history.append(0) - # move sure the cuda chunk will not be evicted! + # make sure the cuda rows will not be evicted! 
with record_function("(zhg) cache update"): self._prepare_rows_on_cuda(comm_cpu_row_idxs) self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) - # new ids chunk_offset + offset_in_chunk - with record_function("(zhg) embed idx -> cache chunk id"): + + with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"): gpu_row_idxs = self._id_to_cached_cuda_id(ids) # update for LFU. @@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module): self._cuda_to_cpu_elapse = 0 self._cuda_to_cpu_numel = 0 - def _chunk_in_cuda(self, chunk_id: int) -> bool: - return self.inverted_cached_idx[chunk_id] != -1 + def _row_in_cuda(self, row_id: int) -> bool: + return self.inverted_cached_idx[row_id] != -1 @torch.no_grad() def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: """prepare rows in cpu_row_idxs on CUDA memory Args: - cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA + cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA """ - evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num + evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num if evict_num > 0: with Timer() as timer: mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) @@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module): """ deprecated - evict one chunk from cuda to cpu. + evict one row from cuda to cpu. Returns: (int) : the slot id be evicted. 
""" diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py index fc28d95c2..d8ecfb611 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py @@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): if self.cache_weight_mgr._cuda_to_cpu_numel > 0: return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ self.cache_weight_mgr._cuda_to_cpu_elapse - return 0 - - @property - def input_id_percent_in_load_chunk(self): - return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100 + return 0 \ No newline at end of file