From 21962e1593e1463492a7b165cde86a3a0b35635b Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 13 Oct 2022 22:22:27 +0800 Subject: [PATCH] [embedding] rename FreqAwareEmbedding -> CachedEmbedding (#1699) --- colossalai/nn/parallel/layers/__init__.py | 10 +-- .../layers/cache_embedding/__init__.py | 14 ++--- .../layers/cache_embedding/cache_mgr.py | 3 +- ...aware_embedding.py => cached_embedding.py} | 10 +-- ...edding.py => parallel_cached_embedding.py} | 8 +-- ...=> parallel_cached_embedding_tablewise.py} | 8 +-- ...cached_embedding_tablewise_split_cache.py} | 62 +++++++++---------- tests/test_layers/test_cache_embedding.py | 38 ++++++------ 8 files changed, 77 insertions(+), 76 deletions(-) rename colossalai/nn/parallel/layers/cache_embedding/{freq_aware_embedding.py => cached_embedding.py} (94%) rename colossalai/nn/parallel/layers/cache_embedding/{parallel_freq_aware_embedding.py => parallel_cached_embedding.py} (96%) rename colossalai/nn/parallel/layers/cache_embedding/{parallel_freq_aware_embedding_tablewise.py => parallel_cached_embedding_tablewise.py} (97%) rename colossalai/nn/parallel/layers/cache_embedding/{parallel_freq_aware_embedding_tablewise_split_cache.py => parallel_cached_embedding_tablewise_split_cache.py} (72%) diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py index 9e1777fa4..29b8353e6 100644 --- a/colossalai/nn/parallel/layers/__init__.py +++ b/colossalai/nn/parallel/layers/__init__.py @@ -3,12 +3,12 @@ from .linear import ColoLinear from .embedding import ColoEmbedding from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module -from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \ - ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache +from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \ + ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache __all__ = [ 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', - 'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr', - 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache' + 'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr', + 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', + 'ParallelCachedEmbeddingBagTablewiseSpiltCache' ] diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py index 7f1c72588..5bbc931a7 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py +++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py @@ -1,13 +1,13 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy from .copyer import LimitBuffIndexCopyer -from .freq_aware_embedding import FreqAwareEmbeddingBag -from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag +from .cached_embedding import CachedEmbeddingBag +from .parallel_cached_embedding import ParallelCachedEmbeddingBag from .embedding_config import TablewiseEmbeddingBagConfig -from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise -from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache +from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise +from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache __all__ = [ - 'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', - 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache' + 'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy', + 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', + 'ParallelCachedEmbeddingBagTablewiseSpiltCache' ] diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index 725daff40..da043df36 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -352,7 +352,8 @@ class CachedParamMgr(torch.nn.Module): # move sure the cuda rows will not be evicted! with record_function("(cache) prepare_rows_on_cuda"): - self._prepare_rows_on_cuda(comm_cpu_row_idxs) + with self.timer("prepare_rows_on_cuda") as timer: + self._prepare_rows_on_cuda(comm_cpu_row_idxs) self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py similarity index 94% rename from colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py rename to colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py index f4704e09e..a0c45d8e8 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -7,10 +7,10 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy from torch.nn.parameter import Parameter -class FreqAwareEmbeddingBag(BaseEmbeddingBag): - """FreqAwareEmbeddingBag +class CachedEmbeddingBag(BaseEmbeddingBag): + """CachedEmbeddingBag - Frequency Aware Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space. + Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space. It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`. You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU. @@ -54,8 +54,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): buffer_size: int = 0, pin_weight: bool = False, evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, - scale_grad_by_freq, sparse, mode, include_last_offset) + super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, + scale_grad_by_freq, sparse, mode, include_last_offset) assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" self.evict_strategy = evict_strategy diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py similarity index 96% rename from colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py rename to colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py index f64917b45..d7f77e195 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from typing import List, Optional, Iterator, Tuple -from .freq_aware_embedding import FreqAwareEmbeddingBag +from .cached_embedding import CachedEmbeddingBag from colossalai.nn._ops._utils import dual_all_to_all from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor @@ -28,7 +28,7 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: return offset, offset + size_list[rank], False -class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): +class ParallelCachedEmbeddingBag(CachedEmbeddingBag): def __init__(self, num_embeddings, @@ -56,7 +56,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): embedding_dim, self.rank, self.world_size) self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index - super(ParallelFreqAwareEmbeddingBag, + super(ParallelCachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight, evict_strategy) @@ -115,7 +115,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): ids_freq_mapping: Optional[List[int]] = None, warmup_ratio: float = 0.7, buffer_size: int = 0, - ) -> 'ParallelFreqAwareEmbeddingBag': + ) -> 'ParallelCachedEmbeddingBag': rows, cols = embedding.shape embedding_bag = cls(rows, cols, diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py similarity index 97% rename from colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py rename to colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py index 60803b928..949f85ad4 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from .freq_aware_embedding import FreqAwareEmbeddingBag +from .cached_embedding import CachedEmbeddingBag from .cache_mgr import EvictionStrategy from .embedding_config import TablewiseEmbeddingBagConfig from colossalai.tensor import ProcessGroup @@ -12,9 +12,9 @@ from typing import List import time -class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): +class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): """ - all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag. + all tables assigned to this class instance are managed by a single CachedEmbeddingBag. Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. """ @@ -62,7 +62,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): self.cache_ratio = cache_ratio # table-associate cache cuda_row_num = int(cache_ratio * self.num_embeddings) - super(ParallelFreqAwareEmbeddingBagTablewise, + super(ParallelCachedEmbeddingBagTablewise, self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight, evict_strategy) diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py similarity index 72% rename from colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py rename to colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py index 807ab389a..cb4647028 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -3,7 +3,7 @@ import torch.distributed as dist import torch.nn as nn from torch.profiler import record_function -from .freq_aware_embedding import FreqAwareEmbeddingBag +from .cached_embedding import CachedEmbeddingBag from colossalai.tensor import ProcessGroup from colossalai.nn._ops._utils import dual_all_to_all_tablewise @@ -14,9 +14,9 @@ from typing import List import abc -class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): +class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): """ - every table assigned to this class instance is managed by a FreqAwareEmbeddingBag. + every table assigned to this class instance is managed by a CachedEmbeddingBag. """ def __init__(self, @@ -34,7 +34,7 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): warmup_ratio=0.7, pin_weight=False, evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__() + super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() self.rank = dist.get_rank() self.world_size = dist.get_world_size() self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] @@ -49,31 +49,31 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): self.include_last_offset = include_last_offset self.pg = ProcessGroup(tp_degree=self.world_size) - # prepare FreqAwareEmbeddingBag list + # prepare CachedEmbeddingBag list - self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList() + self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList() for config in embedding_bag_config_list: if config.assigned_rank != self.rank: continue - self.freq_aware_embedding_bag_list.append( - FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=config.initial_weight, - mode=mode, - include_last_offset=include_last_offset, - dtype=dtype, - device=device, - cuda_row_num=config.cuda_row_num, - ids_freq_mapping=config.ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=config.buffer_size, - pin_weight=pin_weight, - evict_strategy=evict_strategy)) + self.cached_embedding_bag_list.append( + CachedEmbeddingBag(num_embeddings=config.num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=config.initial_weight, + mode=mode, + include_last_offset=include_last_offset, + dtype=dtype, + device=device, + cuda_row_num=config.cuda_row_num, + ids_freq_mapping=config.ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=config.buffer_size, + pin_weight=pin_weight, + evict_strategy=evict_strategy)) # prepare list shape for all_to_all output self.embedding_dim_per_rank = [0 for i in range(self.world_size)] @@ -109,8 +109,8 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): if per_sample_weights != None: local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] with record_function("(tablewise) tablewise forward"): - local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets, - local_per_sample_weights)) + local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets, + local_per_sample_weights)) # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) local_output = torch.cat(local_output_list, 1) @@ -126,13 +126,13 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): def element_size(self): if len(self.assigned_table_list) == 0: return 0 - return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size() + return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size() def print_comm_stats_(self): cuda_to_cpu_elem_num = 0 cpu_to_cuda_elem_num = 0 - for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list: - cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel - cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel + for cached_embedding_bag in self.cached_embedding_bag_list: + cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel + cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem") print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem") diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index 039301a7e..cff9072c7 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -12,8 +12,8 @@ from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ ColoTensor, ColoTensorSpec -from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \ - ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig +from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \ + ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig from typing import List NUM_EMBED, EMBED_DIM = 10, 8 @@ -106,13 +106,13 @@ def test_reorder_with_freq(): def test_freq_aware_embed(use_LFU: bool): device = torch.device('cuda', 0) evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET - model = FreqAwareEmbeddingBag(NUM_EMBED, - EMBED_DIM, - mode='mean', - include_last_offset=True, - cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), - ids_freq_mapping=None, - evict_strategy=evict_strategy).to(device) + model = CachedEmbeddingBag(NUM_EMBED, + EMBED_DIM, + mode='mean', + include_last_offset=True, + cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), + ids_freq_mapping=None, + evict_strategy=evict_strategy).to(device) assert model.weight.shape[0] == NUM_EMBED ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), @@ -151,14 +151,14 @@ def test_freq_aware_embed(use_LFU: bool): @pytest.mark.parametrize('init_freq', [True, False]) def test_lfu_strategy(init_freq: bool): # minimal test to check behavior - Bag = FreqAwareEmbeddingBag(5, - 5, - cache_ratio=3 / 5, - buffer_size=0, - pin_weight=True, - ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, - warmup_ratio=1.0, - evict_strategy=EvictionStrategy.LFU) + Bag = CachedEmbeddingBag(5, + 5, + cache_ratio=3 / 5, + buffer_size=0, + pin_weight=True, + ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, + warmup_ratio=1.0, + evict_strategy=EvictionStrategy.LFU) # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) offsets = torch.tensor([0], device="cuda:0") @@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): _weight = torch.cat([weight_table1, weight_table2], 0) else: _weight = weight_table3 - model = ParallelFreqAwareEmbeddingBagTablewise( + model = ParallelCachedEmbeddingBagTablewise( embedding_bag_config_list, embedding_dim=5, _weight=_weight, @@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): coloweight.set_process_group(ProcessGroup(tp_degree=world_size)) coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D)) - model = ParallelFreqAwareEmbeddingBag.from_pretrained( + model = ParallelCachedEmbeddingBag.from_pretrained( coloweight, include_last_offset=True, freeze=False,