diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py
index 0ebadac6c..1847e0e05 100644
--- a/colossalai/nn/parallel/layers/__init__.py
+++ b/colossalai/nn/parallel/layers/__init__.py
@@ -3,10 +3,10 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer'
+    'LimitBuffIndexCopyer', 'EvictionStrategy'
 ]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
index 10dbe1c8a..e3644dc9c 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
@@ -1,6 +1,9 @@
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
 
-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
+__all__ = [
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
+    'EvictionStrategy'
+]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index fbe24caca..83a51b757 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -4,6 +4,12 @@ from torch.profiler import record_function
 from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
+from enum import Enum
+
+
+class EvictionStrategy(Enum):
+    LFU = 1
+    DATASET = 2
 
 
 class CachedParamMgr(torch.nn.Module):
@@ -18,7 +24,8 @@ class CachedParamMgr(torch.nn.Module):
                  weight: torch.Tensor,
                  cuda_row_num: int = 0,
                  buffer_size: int = 50_000,
-                 pin_weight=False) -> None:
+                 pin_weight=False,
+                 evict_strategy=EvictionStrategy.DATASET) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
@@ -38,6 +45,51 @@ class CachedParamMgr(torch.nn.Module):
         self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()
 
+        self._evict_strategy = evict_strategy
+
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # cpu_row_idx -> frequency, i.e. the access freq of each cpu row.
+            # evict the row with the minimal freq value in the cuda cache.
+            self.register_buffer("freq_cnter",
+                                 torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(0),
+                                 persistent=False)
+
+    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+        """_update_freq_cnter
+
+        Update the frequency values w.r.t. the cpu_row_idxs in self.freq_cnter.
+
+        Args:
+            cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
+ """ + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter[cpu_row_idxs] += 1 + + def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: + """_find_evict_gpu_idxs + + Find the gpu idxs to be evicted, according to their freq. + + Args: + evict_num (int): how many rows has to be evicted + + Returns: + torch.Tensor: a list tensor (1D), contains the gpu_row_idxs. + """ + if self._evict_strategy == EvictionStrategy.LFU: + # find the minimal evict_num freq entries in cached_idx_map + evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num] + return self.cached_idx_map[evict_gpu_row_idxs] + elif self._evict_strategy == EvictionStrategy.DATASET: + # cached_idx_map itself implies the priority of eviction. + # The value of self.cached_idx_map represents cpu_row_idx. + # The larger it is, the less frequently it will appear in the dataset, + # and the higher its eviction priority will be. + return torch.argsort(self.cached_idx_map, descending=True)[:evict_num] + else: + raise TypeError + def _init_weight(self, weight): if self.cuda_row_num > 0: # Enable cache with introducing auxiliary data structures @@ -220,6 +272,10 @@ class CachedParamMgr(torch.nn.Module): # new ids chunk_offset + offset_in_chunk with record_function("(zhg) embed idx -> cache chunk id"): gpu_row_idxs = self._id_to_cached_cuda_id(ids) + + # update for LFU. + self._update_freq_cnter(cpu_row_idxs) + return gpu_row_idxs def _reset_comm_stats(self): @@ -234,6 +290,7 @@ class CachedParamMgr(torch.nn.Module): @torch.no_grad() def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: """prepare rows in cpu_row_idxs on CUDA memory + Args: cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA """ @@ -245,7 +302,9 @@ class CachedParamMgr(torch.nn.Module): invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) self.cached_idx_map.index_fill_(0, invalid_idxs, -2) - evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num] + + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) evict_info = self.cached_idx_map[evict_gpu_row_idxs] @@ -291,8 +350,16 @@ class CachedParamMgr(torch.nn.Module): self._cpu_to_cuda_numel += weight_size # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + def _find_free_cuda_row(self) -> int: + if self._cuda_available_row_num == 0: + return -1 + candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1) + return candidates[0].item() + def _evict(self) -> int: """ + deprecated + evict one chunk from cuda to cpu. Returns: (int) : the slot id be evicted. 
@@ -329,15 +396,11 @@ class CachedParamMgr(torch.nn.Module):
             # self.num_write_back_history[-1] += 1
         return max_cpu_row_idx
 
-    def _find_free_cuda_row(self) -> int:
-        if self._cuda_available_row_num == 0:
-            return -1
-        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
-        return candidates[0].item()
-
     @torch.no_grad()
     def _admit(self, row_id: int):
         """
+        deprecated
+
         move in row_id to CUDA
 
         Args:
diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
index ecf890cf0..fc28d95c2 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
@@ -3,35 +3,35 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
 from .base_embedding import BaseEmbeddingBag
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from torch.nn.parameter import Parameter
 
 
 class FreqAwareEmbeddingBag(BaseEmbeddingBag):
 
-    def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        mode='mean',
-        include_last_offset=False,
-        dtype=None,
-        device=None,
-        cuda_row_num=0,
-        ids_freq_mapping=None,
-        warmup_ratio=0.7,
-        buffer_size=50_000,
-        pin_weight=False,
-    ):
+    def __init__(self,
+                 num_embeddings,
+                 embedding_dim,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 _weight=None,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 device=None,
+                 cuda_row_num=0,
+                 ids_freq_mapping=None,
+                 warmup_ratio=0.7,
+                 buffer_size=50_000,
+                 pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
 
+        self.evict_strategy = evict_strategy
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)
@@ -63,7 +63,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
+        self.cache_weight_mgr = CachedParamMgr(weight,
+                                               cuda_row_num,
+                                               buffer_size,
+                                               pin_weight,
+                                               evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index f238e51e8..71c22e243 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -12,7 +12,7 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
     return indices, offsets
 
 
+@pytest.mark.skip
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
@@ -98,14 +99,17 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
 
 
-def test_freq_aware_embed():
+@pytest.mark.parametrize('use_LFU', [True, False])
+def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
+    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
     model = FreqAwareEmbeddingBag(NUM_EMBED,
                                   EMBED_DIM,
                                   mode='mean',
                                   include_last_offset=True,
                                   cuda_row_num=BATCH_SIZE * 2,
-                                  ids_freq_mapping=None).to(device)
+                                  ids_freq_mapping=None,
+                                  evict_strategy=evict_strategy).to(device)
 
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):
 
 
 if __name__ == '__main__':
-    test_cachemgr()
-    # test_freq_aware_embed()
+    test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
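
For reviewers, a minimal usage sketch of the new `evict_strategy` knob. The table size, cache size, and batch shape below are illustrative only, not taken from this diff:

```python
import torch

from colossalai.nn.parallel.layers import EvictionStrategy, FreqAwareEmbeddingBag

# Illustrative sizes: a 10k-row table with a 512-row CUDA cache.
model = FreqAwareEmbeddingBag(
    num_embeddings=10_000,
    embedding_dim=128,
    mode='mean',
    include_last_offset=True,
    cuda_row_num=512,
    evict_strategy=EvictionStrategy.LFU,    # new in this diff; the default remains DATASET
).to(torch.device('cuda', 0))

# 16 bags of 4 ids each; include_last_offset=True requires the trailing offset.
indices = torch.randint(0, 10_000, (64,), device='cuda')
offsets = torch.arange(0, 64 + 1, 4, device='cuda')
out = model(indices, offsets)    # (16, 128); missing rows are fetched, LFU picks what to evict
```

With `EvictionStrategy.DATASET` the call is identical; only the rows chosen for eviction on a cache miss change.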
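
The LFU selection in `_find_evict_gpu_idxs` reduces to a single argsort over gathered counters. The standalone sketch below replays that idiom with made-up tensors to show which cache slots it picks:

```python
import torch

# Made-up state: freq_cnter[i] is the access count of cpu row i;
# cached_idx_map[slot] is the cpu row currently held in that CUDA cache slot.
freq_cnter = torch.tensor([5, 0, 9, 2, 7, 1])
cached_idx_map = torch.tensor([2, 0, 5, 3])

evict_num = 2
# Gather each slot's frequency, sort ascending, keep the evict_num coldest slots.
evict_slots = torch.argsort(freq_cnter[cached_idx_map])[:evict_num]

print(evict_slots)                  # tensor([2, 3]) -> slots holding rows 5 and 3
print(cached_idx_map[evict_slots])  # tensor([5, 3]) -> the cpu rows being evicted
```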