diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 42f3e0e4b..ef20cfc79 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -293,7 +293,7 @@ class CachedParamMgr(torch.nn.Module):
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function("(pre-id) get unique indices"):
+        with record_function(f"(pre-id) get unique indices. cache ratio {self.cuda_row_num / self.num_embeddings}"):
             ids = ids.to(self._cache_dev)
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
index 282f6d0c4..c816f85b9 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
@@ -27,10 +27,10 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False.
         dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
         device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
-        cuda_row_num (int, optional): the max number of embedding vector in cuda cache. Defaults to 0.
+        cache_ratio (float, optional): the cache ratio #cuda_weight_rows / #cpu_weight_rows. Defaults to 0.01.
         ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None.
         warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7.
-        buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, means do not use the buffer. Defaults to 0.
+        buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
         pin_weight (bool, optional): pin the cpu weight. Defaults to False.
         evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
""" @@ -48,7 +48,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): include_last_offset: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - cuda_row_num: int = 0, + cache_ratio: float = 0.01, ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, warmup_ratio: float = 0.7, buffer_size: int = 0, @@ -57,10 +57,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, mode, include_last_offset) + assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" self.evict_strategy = evict_strategy if _weight is None: _weight = self._weight_alloc(dtype, device) - + cuda_row_num = int(num_embeddings * cache_ratio) # configure weight & cache self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py index 61d870fad..1171496fc 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py @@ -43,7 +43,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): include_last_offset=False, dtype=None, device=None, - cuda_row_num=0, + cache_ratio=0.01, ids_freq_mapping=None, warmup_ratio=0.7, buffer_size=50_000, @@ -58,7 +58,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): super(ParallelFreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping, + sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight, evict_strategy) def _weight_alloc(self, dtype, device): diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py index 731115d3c..ab8a232b1 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py @@ -31,7 +31,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): include_last_offset=False, dtype=None, device=None, - cuda_row_num=0, + cache_ratio=0.01, warmup_ratio=0.7, buffer_size=50_000, pin_weight=False, @@ -59,11 +59,12 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): else: ids_freq_mapping = None break - + self.cache_ratio = cache_ratio # table-associate cache + cuda_row_num = int(cache_ratio * self.num_embeddings) super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping, + sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight, evict_strategy) # for assigned tables reconnection: diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index 928fbef9c..5bb654217 100644 --- a/tests/test_layers/test_cache_embedding.py +++ 
@@ -110,7 +110,7 @@ def test_freq_aware_embed(use_LFU: bool):
                                   EMBED_DIM,
                                   mode='mean',
                                   include_last_offset=True,
-                                  cuda_row_num=BATCH_SIZE * 2,
+                                  cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
                                   ids_freq_mapping=None,
                                   evict_strategy=evict_strategy).to(device)

@@ -153,7 +153,7 @@ def test_lfu_strategy(init_freq: bool):
     # minimal test to check behavior
     Bag = FreqAwareEmbeddingBag(5,
                                 5,
-                                cuda_row_num=3,
+                                cache_ratio=3 / 5,
                                 buffer_size=0,
                                 pin_weight=True,
                                 ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
@@ -238,7 +238,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
         embedding_dim=5,
         _weight=_weight,
         include_last_offset=True,
-        cuda_row_num=8,
+        cache_ratio=0.5,
         buffer_size=0,
         evict_strategy=EvictionStrategy.LFU,
     )
@@ -304,7 +304,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
         coloweight,
         include_last_offset=True,
         freeze=False,
-        cuda_row_num=batch_size * 2,
+        cache_ratio=batch_size * 2 / num_embed,
     )

     assert model.cache_weight_mgr.weight.device.type == 'cpu'
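
Usage note (not part of the patch): after this change, callers size the CUDA software cache as a fraction of the embedding table instead of an absolute row count; internally the constructor computes cuda_row_num = int(num_embeddings * cache_ratio). Below is a minimal sketch, assuming the import path mirrors the module layout in this patch and a CUDA device is available; the table dimensions are made up for illustration.

    import torch
    # import path assumed from the file layout in this patch
    from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag

    NUM_EMBED, EMBED_DIM = 10_000, 16

    # Old API: cuda_row_num=1_000 cached 1,000 rows on the GPU.
    # New API: the same cache size is requested as a ratio of the table height.
    bag = FreqAwareEmbeddingBag(NUM_EMBED,
                                EMBED_DIM,
                                mode='mean',
                                include_last_offset=True,
                                cache_ratio=0.1).to('cuda')    # int(10_000 * 0.1) = 1_000 cached rows

    ids = torch.randint(0, NUM_EMBED, (64,), device='cuda')
    # include_last_offset=True: the final offset equals len(ids)
    offsets = torch.arange(0, 65, 4, device='cuda')
    out = bag(ids, offsets)    # pooled output of shape (16, EMBED_DIM)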