diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 28828b36d..a6ba188fd 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -5,7 +5,7 @@ from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
 from enum import Enum
-
+import sys
 
 class EvictionStrategy(Enum):
     LFU = 1
@@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
                  cuda_row_num: int = 0,
                  buffer_size: int = 50_000,
                  pin_weight=False,
-                 evict_strategy=EvictionStrategy.DATASET) -> None:
+                 evict_strategy=EvictionStrategy.DATASET,) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
         self.cuda_row_num = cuda_row_num
         self._cuda_available_row_num = self.cuda_row_num
         self.pin_weight = pin_weight
 
         self.elem_size_in_byte = weight.element_size()
 
         # weight configure
@@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cpu_row_idx -> frequency, freq of the cpu rows.
             # evict the minimal freq value row in cuda cache.
+
+            '''
+            During cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we temporarily
+            re-map that element to -1. Disabled cached_idx_map elements also map to -1 by default.
+            freq_cnter[-1], the extra last element, must ALWAYS hold the maximum value, so that masked or
+            disabled idxs are argsorted to the end and never chosen for eviction.
+            '''
             self.register_buffer("freq_cnter",
-                                 torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
+                                 torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
                                              dtype=torch.long).fill_(0),
                                  persistent=False)
+            self.freq_cnter[-1] = sys.maxsize
 
-    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+    def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
         """_update_freq_cnter
 
-        Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
+        Update the frequency values w.r.t. cpu_row_idxs_original in self.freq_cnter.
@@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
-            cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
+            cpu_row_idxs_original (torch.Tensor): indices of cpu weight rows, possibly containing duplicates.
         """
         if self._evict_strategy == EvictionStrategy.LFU:
-            self.freq_cnter[cpu_row_idxs] += 1
+            add_num = torch.bincount(cpu_row_idxs_original)
+            self.freq_cnter[:add_num.shape[0]] += add_num
 
     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
@@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
+            if not isinstance(ids_freq_mapping, torch.Tensor):
+                ids_freq_mapping = torch.tensor(ids_freq_mapping)
             tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)
-
+            # initialize freq_cnter if using LFU; rows were just reordered by descending
+            # frequency, so sort the frequencies the same way to keep them aligned.
+            if self._evict_strategy == EvictionStrategy.LFU:
+                self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping, descending=True)
 
         # TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
         # As cuda_cached_weight is very big. You may not have that much available memory!
         # Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
@@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
             torch.Tensor: indices on the cuda_cached_weight.
""" with record_function("(zhg) get unique indices"): - cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids)) - + cpu_row_idxs_original = self.idx_map.index_select(0, ids) + cpu_row_idxs = torch.unique(cpu_row_idxs_original) + assert len(cpu_row_idxs) <= self.cuda_row_num, \ f"the input indices pull {len(cpu_row_idxs)} chunks, " \ f"which is larger than the presented {self.cuda_row_num}, " \ @@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module): # new ids chunk_offset + offset_in_chunk with record_function("(zhg) embed idx -> cache chunk id"): gpu_row_idxs = self._id_to_cached_cuda_id(ids) - + # update for LFU. - self._update_freq_cnter(cpu_row_idxs) - + self._update_freq_cnter(cpu_row_idxs_original) return gpu_row_idxs def _reset_comm_stats(self): @@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module): if evict_num > 0: with Timer() as timer: mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) - - invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) if self._evict_strategy == EvictionStrategy.DATASET: # mask method. # set cached_idx_map[invalid_idxs] to -2. # so those idxs will be sorted to end, therefore not being chosen as victim + invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() self.cached_idx_map.index_fill_(0, invalid_idxs, -2) evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) elif self._evict_strategy == EvictionStrategy.LFU: - # another mask method. - # set freq_cnter[invalid_idxs] to max - # so those idxs will be sorted to end, therefore not being chosen as victim - backup_cnter = self.freq_cnter[invalid_idxs].clone() - self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value? 
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index 62f9df37f..39f55d37a 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
 
 from colossalai.nn._ops._utils import dual_all_to_all
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
-
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 
 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     if world_size == 1:
@@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  warmup_ratio=0.7,
                  buffer_size=50_000,
                  pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
                  ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
-                             warmup_ratio, buffer_size, pin_weight)
+                             warmup_ratio, buffer_size, pin_weight, evict_strategy)
 
     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 99caf1407..f990174fb 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -159,6 +159,9 @@ def test_lfu_strategy():
     offsets = torch.tensor([0],device="cuda:0")
 
     # prepare frequency learning info:
+    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
@@ -182,7 +185,7 @@ def test_lfu_strategy():
     assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
         "LFU strategy behavior failed"
 
 def gather_tensor(tensor, rank, world_size):
     gather_list = []
     if rank == 0:
@@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
 
 if __name__ == '__main__':
-    test_freq_aware_embed(True)
+    # test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
-    # test_lfu_strategy()
\ No newline at end of file
+    test_lfu_strategy()
\ No newline at end of file
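
Note: the extra `Bag.forward` calls added to the test feed duplicate ids within single batches; they only count correctly because `_update_freq_cnter` now applies `torch.bincount` to the pre-`torch.unique` indices instead of incrementing each unique row once. A minimal sketch of that counting behavior, with hypothetical sizes and ids rather than the test's actual values:

```python
import torch

num_embeddings = 4
freq_cnter = torch.zeros(num_embeddings + 1, dtype=torch.long)  # +1 sentinel slot

# One batch in which row 2 is requested three times and row 0 once.
cpu_row_idxs_original = torch.tensor([2, 2, 0, 2])

# The old `freq_cnter[unique_idxs] += 1` would add just 1 to rows 0 and 2;
# bincount instead adds the true per-row occurrence counts.
add_num = torch.bincount(cpu_row_idxs_original)  # tensor([1, 0, 3])
freq_cnter[:add_num.shape[0]] += add_num
print(freq_cnter.tolist())  # [1, 0, 3, 0, 0] -- the sentinel slot stays untouched
```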