From 0aad53c62b5b229d31306e1b50114a2c27a78462 Mon Sep 17 00:00:00 2001
From: Geng Zhang <34452939+zxgx@users.noreply.github.com>
Date: Tue, 23 Aug 2022 17:38:24 +0800
Subject: [PATCH] [FCE] update interface for frequency statistics in
 FreqCacheEmbedding (#1462)

---
 .../layers/cache_embedding/cache_mgr.py        | 12 ++++++++----
 .../cache_embedding/freq_aware_embedding.py    | 17 +++++++++--------
 .../parallel_freq_aware_embedding.py           | 16 ++++++++--------
 tests/test_layers/test_cache_embedding.py      | 10 +++++-----
 4 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 21d58ab41..fbe24caca 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -14,12 +14,17 @@ class CachedParamMgr(torch.nn.Module):
     During training, GPU needs to transmit rows between CPU and GPU.
     """
 
-    def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None:
+    def __init__(self,
+                 weight: torch.Tensor,
+                 cuda_row_num: int = 0,
+                 buffer_size: int = 50_000,
+                 pin_weight=False) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
         self.cuda_row_num = cuda_row_num
         self._cuda_available_row_num = self.cuda_row_num
+        self.pin_weight = pin_weight
 
         self.elem_size_in_byte = weight.element_size()
 
@@ -43,8 +48,7 @@ class CachedParamMgr(torch.nn.Module):
                             dtype=weight.dtype))
 
             # pin memory cpu for higher CPU-GPU copy bandwidth
-            self.weight = weight.contiguous().cpu().pin_memory()
-
+            self.weight = weight.pin_memory() if self.pin_weight else weight
             # map original id to new id with respect to frequency
             # id -> cpu_row_idx
             self.register_buffer(
@@ -109,7 +113,7 @@ class CachedParamMgr(torch.nn.Module):
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
-            tmp_idx = torch.argsort(torch.from_numpy(ids_freq_mapping).cuda(), descending=True)
+            tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)
 
diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
index 7489081ea..ecf890cf0 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
@@ -27,20 +27,19 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
+                 pin_weight=False,
                  ):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
 
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)
-        else:
-            _weight = _weight
 
         # configure weight & cache
-        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size)
+        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
 
     def _weight_alloc(self, dtype, device):
-        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device, pin_memory=True)
+        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
         with torch.no_grad():
             weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
             if self.padding_idx is not None:
@@ -52,7 +51,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                     cuda_row_num: int,
                     ids_freq_mapping: Optional[List[int]] = None,
                     warmup_ratio=0.7,
-                    buffer_size=50_000):
+                    buffer_size=50_000,
+                    pin_weight=False):
         """
         Called after initialized.
         Reorder the weight rows according to the ids_freq_mapping.
@@ -63,17 +63,18 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size)
+        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
-    def forward(self, indices, offsets=None, per_sample_weights=None):
+    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
 
         embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
-
+        if shape_hook is not None:
+            embeddings = shape_hook(embeddings)
         return embeddings
 
     @property
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index d7a51eb78..62f9df37f 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -3,8 +3,6 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
 from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .cache_mgr import CachedParamMgr
-from torch.nn.parameter import Parameter
 from colossalai.nn._ops._utils import dual_all_to_all
 
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -49,6 +47,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
+                 pin_weight=False,
                  ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -60,17 +59,18 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
-                             warmup_ratio, buffer_size)
+                             warmup_ratio, buffer_size, pin_weight)
 
     def _weight_alloc(self, dtype, device):
+        weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
+        with torch.no_grad():
+            weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
+            if self.padding_idx is not None:
+                weight[self.padding_idx].fill_(0)
         colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
                                           dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                           compute_attr=ComputePattern.TP1D)
-        return ColoTensor.from_torch_tensor(torch.empty(self.num_embeddings,
-                                                        self.embedding_dim_per_partition,
-                                                        device=device,
-                                                        dtype=dtype),
-                                            spec=colo_tensor_spec)
+        return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
 
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
         with torch.no_grad():
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index cf3500694..f238e51e8 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -44,7 +44,7 @@ def synthesize_1d_sparse_feature(
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
-    mgr = CachedParamMgr(model.weight, 5)
+    mgr = CachedParamMgr(model.weight.detach(), 5)
     assert mgr.cuda_row_num == 5
 
     mgr._admit(1)
@@ -74,8 +74,8 @@ def test_reorder_with_freq():
     chunk_size = 1
     num_chunk = 5
 
-    idx_map = np.random.randint(10000, size=(num_embed,))
-    sorted_idx = np.flipud(np.argsort(idx_map)).tolist()
+    idx_map = torch.randint(10000, size=(num_embed,))
+    sorted_idx = torch.argsort(idx_map, descending=True).tolist()
     chunkid, offset_in_chunk = [], []
     for i in range(num_embed):
         idx = sorted_idx.index(i)
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
 
 
 if __name__ == '__main__':
-    # test_cachemgr()
+    test_cachemgr()
     # test_freq_aware_embed()
-    test_parallel_freq_aware_embed(2)
+    # test_parallel_freq_aware_embed(2)
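
Usage note (illustrative, not part of the patch): the sketch below exercises the updated FreqAwareEmbeddingBag interface introduced here, namely passing ids_freq_mapping as a torch tensor instead of a numpy array, opting into pinned CPU weights via the new pin_weight flag, and reshaping the lookup output through the new shape_hook argument of forward(). The import path, tensor sizes, cuda_row_num value, the shape_hook lambda, and the availability of a CUDA device are assumptions made for this example only.

    import torch
    from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag  # assumed export path

    num_embeddings, embedding_dim = 10_000, 128
    # frequency statistics are now a torch tensor; CachedParamMgr.reorder() argsorts it directly
    ids_freq_mapping = torch.randint(0, 1000, (num_embeddings,))

    bag = FreqAwareEmbeddingBag(num_embeddings,
                                embedding_dim,
                                cuda_row_num=512,                    # rows kept in the CUDA cache (example value)
                                ids_freq_mapping=ids_freq_mapping,
                                warmup_ratio=0.7,
                                pin_weight=True)                     # new flag: pin the CPU-resident weight

    indices = torch.randint(0, num_embeddings, (64,), device='cuda')
    offsets = torch.arange(0, 64, 4, device='cuda')                  # 16 bags of 4 ids each
    # shape_hook runs on the embedding_bag output, e.g. to regroup bags before returning
    out = bag(indices, offsets, shape_hook=lambda emb: emb.view(-1, 4, embedding_dim))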