diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 61b15c8e1..296d18689 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -172,44 +172,53 @@ class CachedParamMgr(torch.nn.Module):
             ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
 
-        if ids_freq_mapping is not None:
-            if not isinstance(ids_freq_mapping, torch.Tensor):
-                ids_freq_mapping = torch.tensor(ids_freq_mapping)
-            tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
-            sorted_idx = torch.argsort(tmp_idx)
-            self.idx_map.data.copy_(sorted_idx)
+        # reorder phase: reorder the cpu weight according to their freq stats in the target dataset.
+        # reorder only works for DATASET eviction strategy.
+        if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor):
+            ids_freq_mapping = torch.tensor(ids_freq_mapping)
+
+        if self._evict_strategy == EvictionStrategy.DATASET:
+            if ids_freq_mapping is not None:
+                tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
+                sorted_idx = torch.argsort(tmp_idx)
+                self.idx_map.data.copy_(sorted_idx)
+
+        # warmup phase: copy #preload_row_num rows from cpu to gpu.
         preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
         if preload_row_num > 0:
             with Timer() as timer:
                 # extract rows from cpu weight
-                preload_row_ids = torch.arange(preload_row_num)
-                preload_cuda_row_idxs = preload_row_ids.cuda()
+                if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
+                    freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
+                    preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
+                else:
+                    preload_cpu_ids = torch.arange(preload_row_num)
+                    preload_cuda_row_idxs = preload_cpu_ids.cuda()
 
                 if self.buffer_size > 0:
                     self.limit_buff_index_copyer.index_copy(0,
-                                                            src_index=preload_row_ids,
+                                                            src_index=preload_cpu_ids,
                                                             tgt_index=preload_cuda_row_idxs,
                                                             src=self.weight.view(self.num_embeddings, -1),
                                                             tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                 else:
-                    preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
+                    preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
                     self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
                                                                                     preload_rows)
 
                 # update auxiliary info
-                slot_offsets = preload_cuda_row_idxs
-                self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
+                self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
+                self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
+                self._cuda_available_row_num -= preload_row_num
 
                 if self._evict_strategy == EvictionStrategy.LFU:
                     # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
                     if ids_freq_mapping is None:
                         self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
                     else:
-                        self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
+                        self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
 
-                self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
-                self._cuda_available_row_num -= preload_row_num
             print(f'Cache warmup finished cost {timer.elapsed} sec.')
 
     def flush(self):
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 3cd010dd9..2f7ee579e 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -144,49 +144,52 @@ def test_freq_aware_embed(use_LFU: bool):
     assert torch.allclose(model_weight, ref_weight), \
         f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
 
 
-def test_lfu_strategy():
-    # minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(
-        5,
-        5,
-        cuda_row_num=3,
-        buffer_size=0,
-        pin_weight=True,
-        warmup_ratio=0.0,
-        evict_strategy=EvictionStrategy.LFU
-    )
-    offsets = torch.tensor([0],device="cuda:0")
+@pytest.mark.parametrize('init_freq', [True, False])
+def test_lfu_strategy(init_freq: bool):
+    # minimal test to check behavior
+    Bag = FreqAwareEmbeddingBag(5,
+                                5,
+                                cuda_row_num=3,
+                                buffer_size=0,
+                                pin_weight=True,
+                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                                warmup_ratio=1.0,
+                                evict_strategy=EvictionStrategy.LFU)
+
+    # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
+    offsets = torch.tensor([0], device="cuda:0")
 
     # prepare frequency learning info:
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
 
     # check strategy
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([3],device="cuda:0"),offsets) # miss, evict 1
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
-    Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 3
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets) # hit
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([3], device="cuda:0"), offsets)    # miss, evict 1
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([4], device="cuda:0"), offsets)    # miss, evict 3
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)    # hit
 
     assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
         "LFU strategy behavior failed"
-    
+
+
 def gather_tensor(tensor, rank, world_size):
     gather_list = []
     if rank == 0:
@@ -279,4 +282,4 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
     # test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
-    test_lfu_strategy()
\ No newline at end of file
+    test_lfu_strategy(False)
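
Usage note (illustrative, not part of the patch): the sketch below shows how the reworked warmup is meant to be driven from user code, mirroring the parametrized test above. The constructor keywords come from the test in this diff; the import path and the toy frequency table are assumptions, so adjust them to your installation.

# usage_sketch.py -- illustrative only; the import path below is an assumption.
import torch

from colossalai.nn.parallel.layers.cache_embedding import EvictionStrategy, FreqAwareEmbeddingBag

# Per-id access counts collected from the target dataset (offset = id, value = frequency).
ids_freq_mapping = [4, 2, 1, 3, 1]

# With evict_strategy=LFU and ids_freq_mapping given, warmup now preloads the
# top-k most frequent rows, k = min(ceil(cuda_row_num * warmup_ratio), num_embeddings)
# (here the 3 rows for ids 0, 3 and 1), and seeds freq_cnter with the dataset
# counts instead of zeros.
bag = FreqAwareEmbeddingBag(5,    # num_embeddings
                            5,    # embedding_dim
                            cuda_row_num=3,
                            buffer_size=0,
                            pin_weight=True,
                            ids_freq_mapping=ids_freq_mapping,
                            warmup_ratio=1.0,
                            evict_strategy=EvictionStrategy.LFU)

# Lookups then behave as before; ids 0, 3 and 1 start out cached on the GPU.
indices = torch.tensor([0, 3], device="cuda:0")
offsets = torch.tensor([0], device="cuda:0")
out = bag.forward(indices, offsets)

Without ids_freq_mapping (or with the DATASET eviction strategy) the warmup falls back to preloading the first preload_row_num rows, which matches the previous behavior.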