From 1b8fee8e9c1d74095436363f33f0373fbc9e55c9 Mon Sep 17 00:00:00 2001
From: CsRic <59389055+CsRic@users.noreply.github.com>
Date: Mon, 29 Aug 2022 11:44:55 +0800
Subject: [PATCH] [FAW] shrink freq_cnter size (#1509)

---
 .../layers/cache_embedding/cache_mgr.py    | 69 ++++++++-----------
 tests/test_layers/test_cache_embedding.py  |  3 +-
 2 files changed, 31 insertions(+), 41 deletions(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 19fe5d35d..0e6bc4ecd 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -14,6 +14,7 @@ class EvictionStrategy(Enum):
     DATASET = 2
 
 
+
 class CachedParamMgr(torch.nn.Module):
     """
     Manage Embedding Weights on CPU and CUDA memory uses a software cache.
@@ -46,7 +47,6 @@ class CachedParamMgr(torch.nn.Module):
         self.cuda_row_num = cuda_row_num
         self._cuda_available_row_num = self.cuda_row_num
         self.pin_weight = pin_weight
-
         self.elem_size_in_byte = weight.element_size()
 
         # weight configure
@@ -61,31 +61,13 @@ class CachedParamMgr(torch.nn.Module):
         self._evict_strategy = evict_strategy
 
         if self._evict_strategy == EvictionStrategy.LFU:
-            # cpu_row_idx -> frequency, freq of the cpu rows.
-            # evict the minimal freq value row in cuda cache.
-            '''
-            The last element of `freq_cnter` is set to the maximum value of int.
-            The rows store nothing (not used) in the `self.cuda_weight` whose value is -1 in `self.cached_idx_map`.
-            In this way, the not used rows are placed at the end of the sorted.
-            '''
+            # cache_row_idx -> frequency, freq of the cache rows.
+            # classic lfu cache. evict the minimal freq value row in cuda cache.
             self.register_buffer("freq_cnter",
-                                 torch.empty(self.num_embeddings + 1,
+                                 torch.empty(self.cuda_row_num,
                                              device=torch.cuda.current_device(),
-                                             dtype=torch.long).fill_(0),
+                                             dtype=torch.long).fill_(sys.maxsize),
                                  persistent=False)
-            self.freq_cnter[-1] = sys.maxsize
-
-    def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
-        """_update_freq_cnter
-
-        Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
-
-        Args:
-            cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
-        """
-        if self._evict_strategy == EvictionStrategy.LFU:
-            add_num = torch.bincount(cpu_row_idxs_original)
-            self.freq_cnter[:add_num.shape[0]] += add_num
 
     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
@@ -100,14 +82,15 @@ class CachedParamMgr(torch.nn.Module):
         """
         if self._evict_strategy == EvictionStrategy.LFU:
             # find the minimal evict_num freq entries in cached_idx_map
-            evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
+            _, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False)
             return evict_gpu_row_idxs
         elif self._evict_strategy == EvictionStrategy.DATASET:
             # cached_idx_map itself implies the priority of eviction.
             # The value of self.cached_idx_map represents cpu_row_idx.
             # The larger it is, the less frequently it will appear in the dataset,
             # and the higher its eviction priority will be.
-            return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+            _, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
+            return evict_gpu_row_idxs
         else:
             raise TypeError
 
@@ -181,8 +164,7 @@ class CachedParamMgr(torch.nn.Module):
         Execute only once before training, also known as warmup phase.
 
         :NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
-        :NOTE If you are use the LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
-        You can also call this function to inialized the `freq_cnter` with dataset frequency statistics.
+        :NOTE If you use the LFU as the eviction strategy, you can skip this function.
 
         Args:
             ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
@@ -194,9 +176,6 @@ class CachedParamMgr(torch.nn.Module):
             tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)
-            #initialize freq_cnter if use LFU
-            if self._evict_strategy == EvictionStrategy.LFU:
-                self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
 
         preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
         if preload_row_num > 0:
@@ -218,6 +197,8 @@ class CachedParamMgr(torch.nn.Module):
                 # update auxiliary info
                 slot_offsets = preload_slot_ids
                 self.cached_idx_map[preload_slot_ids] = preload_slot_ids
+                if self._evict_strategy == EvictionStrategy.LFU:
+                    self.freq_cnter.index_fill_(0, preload_slot_ids, 0)
                 self.inverted_cached_idx[preload_slot_ids] = slot_offsets
                 self._cuda_available_row_num -= preload_row_num
             print(f'Cache warmup finished cost {timer.elapsed} sec.')
@@ -234,6 +215,8 @@ class CachedParamMgr(torch.nn.Module):
         self.inverted_cached_idx.index_fill_(0, row_ids, -1)
         self._cuda_available_row_num += slots.numel()
 
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter.fill_(sys.maxsize)
         assert self._cuda_available_row_num == self.cuda_row_num
         assert torch.all(self.inverted_cached_idx == -1).item()
         assert torch.all(self.cached_idx_map == -1).item()
@@ -275,8 +258,7 @@ class CachedParamMgr(torch.nn.Module):
             torch.Tensor: indices on the cuda_cached_weight.
         """
         with record_function("(zhg) get unique indices"):
-            cpu_row_idxs_original = self.idx_map.index_select(0, ids)
-            cpu_row_idxs = torch.unique(cpu_row_idxs_original)
+            cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
 
         assert len(cpu_row_idxs) <= self.cuda_row_num, \
             f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
@@ -301,7 +283,10 @@ class CachedParamMgr(torch.nn.Module):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
 
         # update for LFU.
-        self._update_freq_cnter(cpu_row_idxs_original)
+        if self._evict_strategy == EvictionStrategy.LFU:
+            unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
+            self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
+
         return gpu_row_idxs
 
     def _reset_comm_stats(self):
@@ -324,23 +309,21 @@ class CachedParamMgr(torch.nn.Module):
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
-
+                invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
                     # set cached_idx_map[invalid_idxs] to -2.
                     # so those idxs will be sorted to end, therefore not being chosen as victim
-                    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
 
                 elif self._evict_strategy == EvictionStrategy.LFU:
-                    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
-                    backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
-                    self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
+                    backup_freqs = self.freq_cnter[invalid_idxs].clone()
+                    self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
-                    self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
+                    self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
 
                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
 
@@ -357,6 +340,7 @@ class CachedParamMgr(torch.nn.Module):
 
                 self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
                 self.inverted_cached_idx.index_fill_(0, evict_info, -1)
+                # self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary
                 self._cuda_available_row_num += evict_num
 
             weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
@@ -379,6 +363,8 @@ class CachedParamMgr(torch.nn.Module):
             slot_offsets = slots
             self.cached_idx_map[slots] = cpu_row_idxs
             self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
+            if self._evict_strategy == EvictionStrategy.LFU:
+                self.freq_cnter.index_fill_(0, slots, 0)
            self._cuda_available_row_num -= cpu_row_idxs.numel()
             self._cpu_to_cuda_elpase += timer.elapsed
         weight_size = cpu_row_idxs.numel() * self.embedding_dim
@@ -421,7 +407,8 @@ class CachedParamMgr(torch.nn.Module):
 
         # update inverted_cached_idx, min_slot_id is evicted from cuda
         self.cached_idx_map[max_cpu_row_idx] = -1
-
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[max_cpu_row_idx] = sys.maxsize
         self.inverted_cached_idx[max_gpu_row_idx] = -1
 
         self._cuda_available_row_num += 1
@@ -456,6 +443,8 @@ class CachedParamMgr(torch.nn.Module):
 
         # update the inverted_cached_idx
         self.cached_idx_map[slot_id] = row_id
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[slot_id] = 0
         self.inverted_cached_idx[row_id] = slot_offset
 
         self._cuda_available_row_num -= 1
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index f990174fb..3cd010dd9 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -177,9 +177,10 @@ def test_lfu_strategy():
 
     # check strategy
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([3],device="cuda:0"),offsets) # miss, evict 1
     Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
-    Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 1
+    Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 3
     Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
     Bag.forward(torch.tensor([0],device="cuda:0"),offsets) # hit
 
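
Note on the mechanism (not part of the patch): the slot-level LFU bookkeeping introduced above can be modelled in isolation as follows. This is an illustrative sketch only; the helper names (cache_slots, admit, touch, pick_victims) are hypothetical and do not exist in the code base, where the equivalent logic lives in CachedParamMgr.freq_cnter.

    import sys
    import torch

    # One counter per GPU cache slot (cuda_row_num entries), not per embedding row.
    # Empty slots hold sys.maxsize so topk(..., largest=False) never picks them as victims.
    cache_slots = 3
    freq_cnter = torch.full((cache_slots,), sys.maxsize, dtype=torch.long)

    def admit(slot: int) -> None:
        # A slot that just received a row starts counting from zero.
        freq_cnter[slot] = 0

    def touch(slots: torch.Tensor, counts: torch.Tensor) -> None:
        # Accumulate how often each cached slot was hit in the current batch.
        freq_cnter.scatter_add_(0, slots, counts)

    def pick_victims(evict_num: int) -> torch.Tensor:
        # Evict the least-frequently-used occupied slots.
        return torch.topk(freq_cnter, evict_num, largest=False).indices

    # Toy usage: slot 1 receives the fewest hits, so it is the first eviction victim.
    admit(0); admit(1); admit(2)
    touch(torch.tensor([0, 1, 2]), torch.tensor([2, 1, 3]))
    print(pick_victims(1))    # tensor([1])

Keeping one counter per cache slot instead of one per embedding row (the old num_embeddings + 1 buffer) is what shrinks freq_cnter, and sys.maxsize doubles as the "empty slot" marker, so no separate validity mask is needed.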