From 9a9ef65313f0fcdc4ab053348d5b3e02fa2e8c6a Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 30 Aug 2022 14:50:02 +0800
Subject: [PATCH] [FAW] cpu caching operations (#1520)

---
 .../layers/cache_embedding/cache_mgr.py      | 84 ++++++++++++-------
 .../cache_embedding/freq_aware_embedding.py  |  6 +-
 .../parallel_freq_aware_embedding.py         | 47 +++++------
 tests/test_layers/test_cache_embedding.py    | 13 +--
 4 files changed, 86 insertions(+), 64 deletions(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 296d18689..08c206ec9 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -30,6 +30,7 @@ class CachedParamMgr(torch.nn.Module):
             `EvictionStrategy.LFU`: use the least frequently used cache.
             `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
             Defaults to EvictionStrategy.DATASET.
+        use_cpu_caching (bool, optional): use cpu to execute cache indexing. It is slower than using the gpu. Defaults to False.
     """
 
     def __init__(
@@ -39,6 +40,7 @@ class CachedParamMgr(torch.nn.Module):
         buffer_size: int = 50_000,
         pin_weight: bool = False,
         evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
+        use_cpu_caching=False,
     ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
@@ -48,6 +50,13 @@ class CachedParamMgr(torch.nn.Module):
         self.pin_weight = pin_weight
         self.elem_size_in_byte = weight.element_size()
 
+        self._cpu_caching = use_cpu_caching
+
+        if self._cpu_caching:
+            self._cache_dev = torch.device('cpu')
+        else:
+            self._cache_dev = torch.cuda.current_device()
+
         # weight configure
         self._init_weight(weight)
 
@@ -62,10 +71,15 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
- self.register_buffer("freq_cnter", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(sys.maxsize), - persistent=False) + if self._cpu_caching: + self.freq_cnter = torch.empty(self.cuda_row_num, device=self._cache_dev, + dtype=torch.long).fill_(sys.maxsize) + + else: + self.register_buffer("freq_cnter", + torch.empty(self.cuda_row_num, device=self._cache_dev, + dtype=torch.long).fill_(sys.maxsize), + persistent=False) def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: """_find_evict_gpu_idxs @@ -105,26 +119,32 @@ class CachedParamMgr(torch.nn.Module): self.weight = weight.pin_memory() if self.pin_weight else weight # map original id to new id with respect to frequency # id -> cpu_row_idx - self.register_buffer( - "idx_map", - torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), - persistent=False, - ) - # cached_idx_map: gpu_row_idx -> cpu_row_idx - self.register_buffer("cached_idx_map", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + if self._cpu_caching: + self.idx_map = torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev) + self.cached_idx_map = torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1) + self.inverted_cached_idx = torch.zeros(self.num_embeddings, device=self._cache_dev, + dtype=torch.long).fill_(-1) + else: + self.register_buffer( + "idx_map", + torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev), + persistent=False, + ) - # cpu_row_id -> gpu_row_idx. - # gpu_row_idx as -1 means cpu_row_id not in CUDA. - self.register_buffer("inverted_cached_idx", - torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + # cached_idx_map: gpu_row_idx -> cpu_row_idx + self.register_buffer("cached_idx_map", + torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1), + persistent=False) - self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) + # cpu_row_id -> gpu_row_idx. + # gpu_row_idx as -1 means cpu_row_id not in CUDA. + self.register_buffer("inverted_cached_idx", + torch.zeros(self.num_embeddings, device=self._cache_dev, + dtype=torch.long).fill_(-1), + persistent=False) + + self.evict_backlist = torch.tensor([], device=self._cache_dev) # index copy buffer size should less than 10% of cuda weight. 
if self.buffer_size > 0: @@ -191,24 +211,24 @@ class CachedParamMgr(torch.nn.Module): # extract rows from cpu weight if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None: freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True) - preload_cuda_row_idxs = torch.arange(preload_row_num).cuda() + preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev) else: preload_cpu_ids = torch.arange(preload_row_num) - preload_cuda_row_idxs = preload_cpu_ids.cuda() + preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev) if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=preload_cpu_ids, - tgt_index=preload_cuda_row_idxs, + tgt_index=preload_cuda_row_idxs.cuda(), src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs.cuda(), preload_rows) # update auxiliary info - self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() + self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.to(self._cache_dev) self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs self._cuda_available_row_num -= preload_row_num @@ -217,7 +237,7 @@ class CachedParamMgr(torch.nn.Module): if ids_freq_mapping is None: self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0) else: - self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() + self.freq_cnter[preload_cuda_row_idxs] = freq_value.to(self._cache_dev) print(f'Cache warmup finished cost {timer.elapsed} sec.') @@ -227,7 +247,7 @@ class CachedParamMgr(torch.nn.Module): """ slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) row_ids = self.cached_idx_map[slots] - rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots.cuda()).cpu() self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows) self.cached_idx_map.index_fill_(0, slots, -1) self.inverted_cached_idx.index_fill_(0, row_ids, -1) @@ -276,6 +296,7 @@ class CachedParamMgr(torch.nn.Module): torch.Tensor: indices on the cuda_cached_weight. """ with record_function("(zhg) get unique indices"): + ids = ids.to(self._cache_dev) cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) assert len(cpu_row_idxs) <= self.cuda_row_num, \ @@ -353,7 +374,8 @@ class CachedParamMgr(torch.nn.Module): tgt=self.weight.view(self.num_embeddings, -1)) else: # allocate tmp memory on CPU and copy rows on CUDA to CPU. 
- rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu() + rows = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs.cuda()).cpu() self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows) self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) @@ -372,12 +394,12 @@ class CachedParamMgr(torch.nn.Module): if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=cpu_row_idxs.cpu(), - tgt_index=slots, + tgt_index=slots.cuda(), src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows) + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots.cuda(), rows) slot_offsets = slots self.cached_idx_map[slots] = cpu_row_idxs self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py index d8ecfb611..58352a70d 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py @@ -74,8 +74,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): with torch.no_grad(): reorder_ids = self.cache_weight_mgr.prepare_ids(indices) - embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, + self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, per_sample_weights, self.include_last_offset, self.padding_idx) if shape_hook is not None: embeddings = shape_hook(embeddings) @@ -119,4 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): if self.cache_weight_mgr._cuda_to_cpu_numel > 0: return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ self.cache_weight_mgr._cuda_to_cpu_elapse - return 0 \ No newline at end of file + return 0 diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py index 39f55d37a..8213926ae 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py @@ -8,6 +8,7 @@ from colossalai.nn._ops._utils import dual_all_to_all from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor from .cache_mgr import CachedParamMgr, EvictionStrategy + def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: if world_size == 1: return 0, embedding_dim, True @@ -29,27 +30,25 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): - def __init__( - self, - num_embeddings, - embedding_dim, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cuda_row_num=0, - ids_freq_mapping=None, - warmup_ratio=0.7, 
- buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.DATASET - ): + def __init__(self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cuda_row_num=0, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() @@ -60,7 +59,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): super(ParallelFreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight,evict_strategy) + warmup_ratio, buffer_size, pin_weight, evict_strategy) def _weight_alloc(self, dtype, device): weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype) @@ -77,8 +76,8 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): with torch.no_grad(): reorder_ids = self.cache_weight_mgr.prepare_ids(indices) - output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, + self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, per_sample_weights, self.include_last_offset, self.padding_idx) if shape_hook is not None: diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index 2f7ee579e..50fbb732c 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -83,15 +83,16 @@ def test_reorder_with_freq(): chunkid.append(idx // chunk_size) offset_in_chunk.append(idx % chunk_size) - chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device()) - offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device()) + dev = torch.device('cuda') + chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) + offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) weight = torch.rand(num_embed, 2) - mgr = CachedParamMgr(weight, num_chunk) + mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu') mgr.reorder(idx_map) - indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=torch.cuda.current_device())) + indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor') mgr_offsets = torch.remainder(indices, chunk_size) assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" @@ -280,6 +281,6 @@ def test_parallel_freq_aware_embed(world_size): if __name__ == '__main__': - # test_freq_aware_embed(True) + test_freq_aware_embed(True) # test_parallel_freq_aware_embed(2) - test_lfu_strategy(False) + # test_lfu_strategy(False)
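
A minimal usage sketch of the new use_cpu_caching flag, for readers of this patch. It is not part of the diff: the table sizes, the cuda_row_num value, the toy id batch, and the reorder(None) warm-up call are illustrative assumptions; only CachedParamMgr, use_cpu_caching, prepare_ids, cuda_row_num and cuda_cached_weight come from the hunks above. With cpu caching enabled, the bookkeeping tensors (idx_map, cached_idx_map, inverted_cached_idx) stay on the CPU, so the row indices returned by prepare_ids must be moved to CUDA before indexing cuda_cached_weight, mirroring reorder_ids.cuda() in FreqAwareEmbeddingBag.forward.

import torch

from colossalai.nn.parallel.layers.cache_embedding.cache_mgr import CachedParamMgr

# illustrative sizes; a CUDA device is assumed to be available for the cached rows
num_embeddings, embedding_dim = 10_000, 32
weight = torch.rand(num_embeddings, embedding_dim)    # full embedding table kept on CPU

# second positional argument is cuda_row_num, following CachedParamMgr(weight, num_chunk, ...) in the test
mgr = CachedParamMgr(weight, 1024, use_cpu_caching=True)
mgr.reorder(None)    # warm up the cache without dataset statistics (assumed to accept None)

ids = torch.randint(0, num_embeddings, (128,))    # toy id batch; can stay on the CPU
cuda_row_idxs = mgr.prepare_ids(ids)              # cache indexing now runs on CPU tensors
# gather the cached rows on CUDA, as FreqAwareEmbeddingBag.forward does via reorder_ids.cuda()
rows = mgr.cuda_cached_weight.view(mgr.cuda_row_num, -1).index_select(0, cuda_row_idxs.cuda())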