From 04443605a5a5172f5ab1bacf7ff457fc167fbdf8 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 26 Sep 2022 14:57:57 +0800
Subject: [PATCH] [embedding] non-blocking cpu-gpu copy (#1647)

---
 .../layers/cache_embedding/cache_mgr.py | 54 ++++++++++++++++---
 1 file changed, 48 insertions(+), 6 deletions(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index 6f591ad44..127270da0 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -38,6 +38,7 @@ class CachedParamMgr(torch.nn.Module):
         buffer_size: int = 0,
         pin_weight: bool = False,
         evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
+        async_copy: bool = False,
     ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
@@ -58,6 +59,11 @@ class CachedParamMgr(torch.nn.Module):
 
         self._evict_strategy = evict_strategy
 
+        self._async_copy = async_copy
+
+        if self._async_copy:
+            print('use async copy')
+
         if self._evict_strategy == EvictionStrategy.LFU:
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
@@ -312,6 +318,18 @@ class CachedParamMgr(torch.nn.Module):
             cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
+
+        cpu_row_idxs_copy = cpu_row_idxs.cpu()
+
+        # move evict in rows to gpu
+        if self._async_copy:
+            if self.buffer_size == 0:
+                rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
+                evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
+                evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
+            else:
+                raise NotImplementedError
+
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
@@ -323,12 +341,24 @@ class CachedParamMgr(torch.nn.Module):
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
+                    # move evict out rows to cpu
+                    if self._async_copy:
+                        evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs)
+                        evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
+                        evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
                 elif self._evict_strategy == EvictionStrategy.LFU:
                     backup_freqs = self.freq_cnter[invalid_idxs].clone()
                     self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+                    if self._async_copy:
+                        evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs)
+                        evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
+                        evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
                     self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
 
                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
@@ -341,8 +371,13 @@ class CachedParamMgr(torch.nn.Module):
                                                             tgt=self.weight.view(self.num_embeddings, -1))
                 else:
                     # allocate tmp memory on CPU and copy rows on CUDA to CPU.
-                    rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
-                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
+                    # TODO async gpu -> cpu
+                    if self._async_copy:
+                        pass
+                    else:
+                        evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs).cpu()
+                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
 
                 self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
                 self.inverted_cached_idx.index_fill_(0, evict_info, -1)
@@ -359,13 +394,20 @@ class CachedParamMgr(torch.nn.Module):
         # Here also allocate extra memory on CUDA. #cpu_row_idxs
         if self.buffer_size > 0:
             self.limit_buff_index_copyer.index_copy(0,
-                                                    src_index=cpu_row_idxs.cpu(),
+                                                    src_index=cpu_row_idxs_copy,
                                                     tgt_index=slots,
                                                     src=self.weight.view(self.num_embeddings, -1),
                                                     tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
         else:
-            rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
-            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
+            # TODO async copy cpu -> gpu
+            if self._async_copy:
+                pass
+            else:
+                evict_in_rows_gpu = self.weight.view(self.num_embeddings,
+                                                     -1).index_select(0, cpu_row_idxs_copy).cuda()
+
+            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
+
         slot_offsets = slots
         self.cached_idx_map[slots] = cpu_row_idxs
         self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
@@ -453,4 +495,4 @@ class CachedParamMgr(torch.nn.Module):
         self._cuda_available_row_num -= 1
 
         self._cpu_to_cuda_numel += self.embedding_dim
-        self._cpu_to_cuda_elpase += timer.elapsed
\ No newline at end of file
+        self._cpu_to_cuda_elpase += timer.elapsed
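
A note on the non-blocking copy path (illustration only, not part of the patch): copy_(..., non_blocking=True) only overlaps with host work when the CPU-side tensor lives in pinned (page-locked) memory, which is why the rows are staged through pin_memory() before the host-to-device copy. The copy remains ordered before any kernel later enqueued on the same CUDA stream, while a device-to-host copy additionally needs an explicit synchronization before the host reads the destination buffer. A minimal standalone sketch of the same idea, assuming a CUDA device is available; the names src_cpu/dst_gpu and the 1024 x 128 shape are illustrative, not from the patch:

    import torch

    # pinned host memory is required for a truly asynchronous host-to-device copy
    src_cpu = torch.randn(1024, 128).pin_memory()
    dst_gpu = torch.empty_like(src_cpu, device=torch.cuda.current_device())

    # returns immediately; the DMA transfer overlaps with the CPU work below
    dst_gpu.copy_(src_cpu, non_blocking=True)

    # ... CPU-side bookkeeping (e.g. computing eviction indices) can run here ...

    # kernels later enqueued on the same stream already see the copy as complete;
    # an explicit sync is only needed before the host itself inspects copied data
    torch.cuda.synchronize()

This mirrors the patch's rows_cpu.pin_memory() followed by evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True), with the eviction bookkeeping overlapping the transfer.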