From 9f3eed66eb99e6aa76e44b9beb0ebc5ca5fbfd8a Mon Sep 17 00:00:00 2001 From: Geng Zhang <34452939+zxgx@users.noreply.github.com> Date: Fri, 12 Aug 2022 15:55:46 +0800 Subject: [PATCH] [FAW] reorganize the inheritance struct of FreqCacheEmbedding (#1448) --- .../layers/cache_embedding/cache_mgr.py | 106 +++++++------ .../cache_embedding/freq_aware_embedding.py | 59 ++++++-- .../parallel_freq_aware_embedding.py | 142 ++++++++---------- tests/test_layers/test_cache_embedding.py | 32 ++-- 4 files changed, 189 insertions(+), 150 deletions(-) diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index 79e188b07..de236b120 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -23,48 +23,61 @@ class CachedParamMgr(torch.nn.Module): self.elem_size_in_byte = weight.element_size() - self.cuda_cached_weight = torch.nn.Parameter( - torch.zeros(self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype)) - - if weight.device.type == 'cuda': - weight = weight.cpu() - - # pin memory cpu for higher CPU-GPU copy bandwidth - self.cpu_weight = weight.contiguous().pin_memory() - - # map original id to new id with respect to frequency - # id -> cpu_row_idx - self.register_buffer( - "idx_map", - torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), - persistent=False, - ) - - # cached_idx_map: gpu_row_idx -> cpu_row_idx - self.register_buffer("cached_idx_map", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) - - # cpu_row_id -> gpu_row_idx. - # gpu_row_idx as -1 means cpu_row_id not in CUDA. - self.register_buffer("inverted_cached_idx", - torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) - - self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) - - # index copy buffer size should less than 10% of cuda weight. - if self.buffer_size > 0: - self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size) + # weight configure + self._init_weight(weight) + # Perf log self.num_hits_history = [] self.num_miss_history = [] self.num_write_back_history = [] self.input_id_percent_in_load_chunk = [] self._reset_comm_stats() + def _init_weight(self, weight): + if self.cuda_row_num > 0: + # Enable cache with introducing auxiliary data structures + self.cuda_cached_weight = torch.nn.Parameter( + torch.zeros(self.cuda_row_num, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=weight.dtype)) + + # pin memory cpu for higher CPU-GPU copy bandwidth + self.weight = weight.contiguous().cpu().pin_memory() + + # map original id to new id with respect to frequency + # id -> cpu_row_idx + self.register_buffer( + "idx_map", + torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), + persistent=False, + ) + + # cached_idx_map: gpu_row_idx -> cpu_row_idx + self.register_buffer("cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + # cpu_row_id -> gpu_row_idx. + # gpu_row_idx as -1 means cpu_row_id not in CUDA. + self.register_buffer("inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) + + # index copy buffer size should less than 10% of cuda weight. + if self.buffer_size > 0: + self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size) + + else: + # Disable cache so that FreqCacheEmbedding is compatible with vanilla EmbeddingBag + # self.weight = torch.nn.Parameter(weight) + # self.cuda_cached_weight = self.weight + raise NotImplementedError() + def cpu_weight_data(self, chunk_id: int) -> torch.Tensor: """ access a chunk of CPU weight. @@ -76,9 +89,9 @@ class CachedParamMgr(torch.nn.Module): torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D. """ - return self.cpu_weight.data.view(-1).narrow(0, - int(chunk_id) * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + return self.weight.data.view(-1).narrow(0, + int(chunk_id) * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) @property def cuda_available_chunk_num(self): @@ -86,7 +99,7 @@ class CachedParamMgr(torch.nn.Module): @torch.no_grad() def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7): - """reorder the cpu_weight according to ids' frequency in dataset before training. + """reorder the weight according to ids' frequency in dataset before training. Also Build the IndexMappingTable, aka index_mapping_table. Execute only once before training. Args: @@ -112,11 +125,10 @@ class CachedParamMgr(torch.nn.Module): self.limit_buff_index_copyer.index_copy(0, src_index=preload_row_ids, tgt_index=preload_slot_ids, - src=self.cpu_weight.view(self.num_embeddings, -1), + src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: - preload_chunks = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, - preload_row_ids).cuda() + preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks) # update auxiliary info @@ -133,7 +145,7 @@ class CachedParamMgr(torch.nn.Module): slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) chunk_ids = self.cached_idx_map[slots] chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() - self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks) + self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks) self.cached_idx_map.index_fill_(0, slots, -1) self.inverted_cached_idx.index_fill_(0, chunk_ids, -1) self._cuda_available_row_num += slots.numel() @@ -237,11 +249,11 @@ class CachedParamMgr(torch.nn.Module): src_index=evict_gpu_row_idxs, tgt_index=evict_info.cpu(), src=self.cuda_cached_weight.view(self.cuda_row_num, -1), - tgt=self.cpu_weight.view(self.num_embeddings, -1)) + tgt=self.weight.view(self.num_embeddings, -1)) else: # allocate tmp memory on CPU and copy rows on CUDA to CPU. rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu() - self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows) + self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows) self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) self.inverted_cached_idx.index_fill_(0, evict_info, -1) @@ -259,10 +271,10 @@ class CachedParamMgr(torch.nn.Module): self.limit_buff_index_copyer.index_copy(0, src_index=cpu_row_idxs.cpu(), tgt_index=slots, - src=self.cpu_weight.view(self.num_embeddings, -1), + src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: - rows = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda() + rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda() self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows) slot_offsets = slots self.cached_idx_map[slots] = cpu_row_idxs diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py index 95f6996fa..7544c4674 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py @@ -9,15 +9,50 @@ from torch.nn.parameter import Parameter class FreqAwareEmbeddingBag(BaseEmbeddingBag): - def __init__(self, num_embeddings, embedding_dim, dtype=None, *args, **kwargs): - super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, *args, **kwargs) - self._weight = torch.randn(self.num_embeddings, self.embedding_dim, device='cpu', dtype=dtype) + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cuda_row_num=0, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + ): + super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, + scale_grad_by_freq, sparse, mode, include_last_offset) - def preprocess(self, - cuda_row_num: int, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio=0.7, - buffer_size=50_000): + if _weight is None: + _weight = self._weight_alloc(dtype, device) + else: + _weight = _weight + + # configure weight & cache + self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size) + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device, pin_memory=True) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + return weight + + def _preprocess(self, + weight, + cuda_row_num: int, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio=0.7, + buffer_size=50_000): """ Called after initialized. Reorder the weight rows according to the ids_freq_mapping. @@ -27,7 +62,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): ids_freq_mapping (List[int]): a list, idx is id number, value is freq warmup_ratio (float): the amount of rows preloaded in cuda cache """ - self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size) + self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) def forward(self, indices, offsets=None, per_sample_weights=None): @@ -42,8 +77,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): @property def weight(self): - assert self.cache_weight_mgr is not None - return self.cache_weight_mgr.cpu_weight.narrow(0, 0, self.num_embeddings) + return self.cache_weight_mgr.weight def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: yield 'weight', self.cache_weight_mgr.cuda_cached_weight @@ -51,6 +85,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): def parameters(self, recurse: bool = True) -> Iterator[Parameter]: yield self.cache_weight_mgr.cuda_cached_weight + +############################# Perf Log ################################### + @property def num_hits_history(self): return self.cache_weight_mgr.num_hits_history diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py index ee751435a..d7a51eb78 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py @@ -2,12 +2,12 @@ import torch import torch.nn.functional as F from typing import List, Optional, Iterator, Tuple -from .base_embedding import BaseEmbeddingBag +from .freq_aware_embedding import FreqAwareEmbeddingBag from .cache_mgr import CachedParamMgr from torch.nn.parameter import Parameter from colossalai.nn._ops._utils import dual_all_to_all -from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec +from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: @@ -29,71 +29,48 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: return offset, offset + size_list[rank], False -class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag): - - def __init__(self, - num_embeddings, - embedding_dim, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - debug=True): - super(ParallelFreqAwareEmbeddingBag, - self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, mode, include_last_offset) +class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cuda_row_num=0, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + ): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() - self.debug = debug self.partition_start_index, self.partition_end_index, divisible = get_partition( embedding_dim, self.rank, self.world_size) self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index - if _weight is None: - colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), - dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), - compute_attr=ComputePattern.TP1D) - self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings, - self.embedding_dim_per_partition, - device='cpu', - dtype=dtype), - requires_grad=True, - spec=colo_tensor_spec) - self.init_parameters() - else: - assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter" - self._weight = _weight + super(ParallelFreqAwareEmbeddingBag, + self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, + sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping, + warmup_ratio, buffer_size) - @property - def weight(self): - return self.cache_weight_mgr.cpu_weight - - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - yield 'weight', self.cache_weight_mgr.cuda_cached_weight - - def parameters(self, recurse: bool = True) -> Iterator[Parameter]: - yield self.cache_weight_mgr.cuda_cached_weight - - @torch.no_grad() - def init_parameters(self): - self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) - if self.padding_idx is not None: - self._weight[self.padding_idx].fill_(0) - - def preprocess(self, - cuda_row_num: int, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio: float = 0.7, - buffer_size: int = 50_000): - self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size) - self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) + def _weight_alloc(self, dtype, device): + colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), + dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), + compute_attr=ComputePattern.TP1D) + return ColoTensor.from_torch_tensor(torch.empty(self.num_embeddings, + self.embedding_dim_per_partition, + device=device, + dtype=dtype), + spec=colo_tensor_spec) def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1): with torch.no_grad(): @@ -107,29 +84,42 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag): output_shard = shape_hook(output_shard) output_full = dual_all_to_all(output_shard, - self._weight.get_process_group(), + self.weight.get_process_group(), scatter_dim=scatter_dim, gather_dim=gather_dim) return output_full @classmethod - def from_pretrained(cls, - embedding: torch.Tensor, - freeze: bool = True, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2., - scale_grad_by_freq: bool = False, - sparse: bool = False, - mode: str = 'mean', - include_last_offset: bool = False, - debug: bool = True, - cuda_row_num: int = 100_000, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag': + def from_pretrained( + cls, + embedding: torch.Tensor, + freeze: bool = True, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2., + scale_grad_by_freq: bool = False, + sparse: bool = False, + mode: str = 'mean', + include_last_offset: bool = False, + cuda_row_num: int = 100_000, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 50_000, + ) -> 'ParallelFreqAwareEmbeddingBag': rows, cols = embedding.shape - embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode, - include_last_offset, debug) - embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio) + embedding_bag = cls(rows, + cols, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + embedding, + mode, + include_last_offset, + cuda_row_num=cuda_row_num, + ids_freq_mapping=ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=buffer_size) embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze return embedding_bag diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index d7f6e7ee7..cf3500694 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -10,7 +10,8 @@ import torch.multiprocessing as mp import colossalai from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use -from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec +from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ + ColoTensor, ColoTensorSpec from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag NUM_EMBED, EMBED_DIM = 10, 8 @@ -99,13 +100,12 @@ def test_reorder_with_freq(): def test_freq_aware_embed(): device = torch.device('cuda', 0) - model = FreqAwareEmbeddingBag( - NUM_EMBED, - EMBED_DIM, - mode='mean', - include_last_offset=True, - ).to(device) - model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None) + model = FreqAwareEmbeddingBag(NUM_EMBED, + EMBED_DIM, + mode='mean', + include_last_offset=True, + cuda_row_num=BATCH_SIZE * 2, + ids_freq_mapping=None).to(device) assert model.weight.shape[0] == NUM_EMBED ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), @@ -159,11 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size): set_seed(4321) weight = torch.rand(num_embed, embed_dim) - coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False) + coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None) # initialize the tensor spec for the embedding weight parameter, # which is an ColoParameter. - coloweight.process_group = ProcessGroup(tp_degree=world_size) + coloweight.set_process_group(ProcessGroup(tp_degree=world_size)) coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D)) model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight, @@ -171,12 +171,12 @@ def run_parallel_freq_aware_embed(rank, world_size): freeze=False, cuda_row_num=batch_size * 2) - assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu' + assert model.cache_weight_mgr.weight.device.type == 'cpu' assert model.cache_weight_mgr.cuda_cached_weight.requires_grad weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] - assert torch.allclose( - weight_in_rank, - model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}" + print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") + assert torch.allclose(weight_in_rank, + model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}" optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) @@ -211,7 +211,7 @@ def run_parallel_freq_aware_embed(rank, world_size): ref_optimizer.zero_grad() model.cache_weight_mgr.flush() - weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size) + weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size) if rank == 0: recover_weight = torch.cat(weight_list, dim=1) assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}" @@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size): if __name__ == '__main__': + # test_cachemgr() # test_freq_aware_embed() - # test_chunkmgr_admit() test_parallel_freq_aware_embed(2)