From c638bec02887dd6f19f450174d6d9b04f4f3dd05 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 27 Sep 2022 14:37:03 +0800 Subject: [PATCH] [embedding] polish async copy (#1657) --- .../layers/cache_embedding/cache_mgr.py | 43 +++++++++++++------ .../cache_embedding/freq_aware_embedding.py | 3 ++ .../parallel_freq_aware_embedding.py | 2 +- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index d89290145..5babeb009 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -15,6 +15,23 @@ class EvictionStrategy(Enum): DATASET = 2 +def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: + if stream is None: + return + torch.cuda.current_stream().wait_stream(stream) + # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, + # PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is + # freed, its memory is likely to be reused by newly constructed tensors. By default, + # this allocator traces whether a tensor is still in use by only the CUDA stream where it + # was created. When a tensor is used by additional CUDA streams, we need to call record_stream + # to tell the allocator about all these streams. Otherwise, the allocator might free the + # underlying memory of the tensor once it is no longer used by the creator stream. This is + # a notable programming trick when we write programs using multi CUDA streams. + cur_stream = torch.cuda.current_stream() + assert isinstance(t, torch.Tensor) + t.record_stream(cur_stream) + + class CachedParamMgr(torch.nn.Module): """ Manage Embedding Weights on CPU and CUDA memory uses a software cache. 
@@ -37,7 +54,7 @@ class CachedParamMgr(torch.nn.Module): weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 0, - pin_weight: bool = False, + pin_weight: bool = True, evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, async_copy: bool = False, ) -> None: @@ -62,6 +79,8 @@ class CachedParamMgr(torch.nn.Module): self._async_copy = async_copy if self._async_copy: + self._memcpy_stream = torch.cuda.Stream() + print('use async copy') if self._evict_strategy == EvictionStrategy.LFU: @@ -350,11 +369,10 @@ class CachedParamMgr(torch.nn.Module): # move evict in rows to gpu if self._async_copy: if self.buffer_size == 0: - idxslt_stream = torch.cuda.Stream() - with torch.cuda.stream(idxslt_stream): - rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() - # evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) - # evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) + evict_in_rows_gpu = self.weight.view(self.num_embeddings, + -1).index_select(0, cpu_row_idxs_copy).pin_memory() + with torch.cuda.stream(self._memcpy_stream): + evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) else: raise NotImplemented @@ -378,7 +396,8 @@ class CachedParamMgr(torch.nn.Module): evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs) evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) - evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) + with torch.cuda.stream(None): + evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) elif self._evict_strategy == EvictionStrategy.LFU: @@ -393,7 +412,8 @@ class CachedParamMgr(torch.nn.Module): evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs) evict_out_rows_cpu = 
torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) - evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) + with torch.cuda.stream(None): + evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) with self.timer("3_1_2_find_evict_index_copy") as timer: self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) @@ -410,7 +430,7 @@ class CachedParamMgr(torch.nn.Module): # allocate tmp memory on CPU and copy rows on CUDA to CPU. # TODO async gpu -> cpu if self._async_copy: - pass + _wait_for_data(evict_out_rows_cpu, None) else: with self.timer("3_2_1_evict_out_index_select") as timer: evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, @@ -445,10 +465,7 @@ class CachedParamMgr(torch.nn.Module): tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: if self._async_copy: - torch.cuda.current_stream().wait_stream(idxslt_stream) - evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) - evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) - pass + _wait_for_data(evict_in_rows_gpu, self._memcpy_stream) else: with self.timer("3_4_1_evict_in_index_select") as timer: # narrow index select to a subset of self.weight diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py index 356d77bd2..f4704e09e 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py @@ -66,6 +66,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) self.cache_op = True + def set_cache_mgr_async_copy(self, flag): + self.cache_weight_mgr._async_copy = flag + def _weight_alloc(self, dtype, device): weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) with torch.no_grad(): diff --git 
a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py index 28e6e0575..f64917b45 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py @@ -114,7 +114,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): cuda_row_num: int = 100_000, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio: float = 0.7, - buffer_size: int = 50_000, + buffer_size: int = 0, ) -> 'ParallelFreqAwareEmbeddingBag': rows, cols = embedding.shape embedding_bag = cls(rows,