diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 83a51b757..28828b36d 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -59,7 +59,7 @@ class CachedParamMgr(torch.nn.Module):
         """_update_freq_cnter
 
         Update the frequency value w.r.t. the cpu_row_idxs in self.freq_cnter.
-        
+
         Args:
             cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
         """
@@ -80,7 +80,7 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # find the evict_num entries of cached_idx_map with the smallest frequency
             evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
-            return self.cached_idx_map[evict_gpu_row_idxs]
+            return evict_gpu_row_idxs
         elif self._evict_strategy == EvictionStrategy.DATASET:
             # cached_idx_map itself implies the priority of eviction.
             # The value of self.cached_idx_map represents cpu_row_idx.
@@ -298,15 +298,27 @@ class CachedParamMgr(torch.nn.Module):
 
             if evict_num > 0:
                 with Timer() as timer:
                     mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
-                    backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
+
                     invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
-                    self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
-
-                    evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
-
-                    self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
-
+                    if self._evict_strategy == EvictionStrategy.DATASET:
+                        # mask method:
+                        # set cached_idx_map[invalid_idxs] to -2,
+                        # so those rows are sorted to the end and never chosen as victims
+                        backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
+                        self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
+                        evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+                        self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
+
+                    elif self._evict_strategy == EvictionStrategy.LFU:
+                        # the same mask method:
+                        # set freq_cnter[invalid_idxs] to the current max + 1,
+                        # so those rows are sorted to the end and never chosen as victims
+                        backup_cnter = self.freq_cnter[invalid_idxs].clone()
+                        self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1)    # or can we use a known constant max value?
+                        evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+                        self.freq_cnter.index_copy_(0, invalid_idxs, backup_cnter)
+
                     evict_info = self.cached_idx_map[evict_gpu_row_idxs]
 
                     if self.buffer_size > 0:
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 71c22e243..99caf1407 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -144,6 +144,44 @@ def test_freq_aware_embed(use_LFU: bool):
     assert torch.allclose(model_weight, ref_weight), \
         f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
 
+
+def test_lfu_strategy():
+    # a minimal test of the LFU eviction behavior
+    Bag = FreqAwareEmbeddingBag(5,
+                                5,
+                                cuda_row_num=3,
+                                buffer_size=0,
+                                pin_weight=True,
+                                warmup_ratio=0.0,
+                                evict_strategy=EvictionStrategy.LFU)
+
+    offsets = torch.tensor([0], device="cuda:0")
+
+    # prepare frequency learning info:
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+
+    # check the eviction strategy
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)    # 3 hits
+    Bag.forward(torch.tensor([3], device="cuda:0"), offsets)    # miss, evict one row
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([4], device="cuda:0"), offsets)    # miss, evict one row
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)    # hit
+
+    assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
+        "LFU strategy behavior failed"
+
 
 def gather_tensor(tensor, rank, world_size):
     gather_list = []
@@ -237,3 +275,4 @@
 if __name__ == '__main__':
     test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
+    # test_lfu_strategy()
\ No newline at end of file
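
Note for reviewers: the mask-and-restore trick used by both eviction branches above can be exercised in isolation. The standalone sketch below uses toy values; cached_idx_map, freq_cnter, and evict_backlist here are hypothetical stand-ins for the manager's buffers, and freq_cnter is simplified to one counter per GPU slot. Only torch is assumed.

    # a self-contained sketch of the mask-and-restore victim selection (toy values)
    import torch

    cached_idx_map = torch.tensor([7, 3, 9, 5])    # cpu row id held by each gpu slot
    freq_cnter = torch.tensor([4, 1, 2, 8])        # access count per gpu slot (simplified)
    evict_backlist = torch.tensor([3])             # cpu rows needed by the current batch
    evict_num = 2

    # gpu slots whose rows appear in the backlist must not be evicted
    mask_cpu_row_idx = torch.isin(cached_idx_map, evict_backlist)
    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)

    # mask: temporarily give protected slots the highest frequency,
    # pushing them to the end of the ascending argsort
    backup_cnter = freq_cnter[invalid_idxs].clone()
    freq_cnter.index_fill_(0, invalid_idxs, torch.max(freq_cnter) + 1)

    # pick the evict_num least-frequently-used slots
    evict_gpu_row_idxs = torch.argsort(freq_cnter)[:evict_num]

    # restore: undo the temporary mask in place
    freq_cnter.index_copy_(0, invalid_idxs, backup_cnter)

    print(evict_gpu_row_idxs)                   # tensor([2, 0])
    print(cached_idx_map[evict_gpu_row_idxs])   # tensor([9, 7]) -> cpu rows to write back

Because index_fill_ and index_copy_ operate in place, the counters are bit-identical after the restore, so the temporary mask never leaks into later frequency updates.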