
[embedding] non-blocking cpu-gpu copy (#1647)

Jiarui Fang, 2 years ago, committed by GitHub · commit 04443605a5
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py (54 changed lines)
@@ -38,6 +38,7 @@ class CachedParamMgr(torch.nn.Module):
                  buffer_size: int = 0,
                  pin_weight: bool = False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
+                 async_copy: bool = False,
                  ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
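
For context, a minimal usage sketch of the new flag. The weight and cuda_row_num arguments and the import paths are assumptions about parts of the class not shown in this diff; only buffer_size, pin_weight, evict_strategy and async_copy appear in the hunk above.

    import torch
    from colossalai.nn.parallel.layers.cache_embedding.cache_mgr import CachedParamMgr, EvictionStrategy

    weight = torch.randn(10000, 128)          # hypothetical CPU-resident embedding table
    mgr = CachedParamMgr(weight,
                         cuda_row_num=1024,   # assumed parameter name
                         buffer_size=0,       # the async path below currently requires buffer_size == 0
                         pin_weight=True,
                         evict_strategy=EvictionStrategy.DATASET,
                         async_copy=True)     # opt in to non-blocking CPU-GPU copies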
@@ -58,6 +59,11 @@ class CachedParamMgr(torch.nn.Module):
         self._evict_strategy = evict_strategy
+        self._async_copy = async_copy
+
+        if self._async_copy:
+            print('use async copy')
+
         if self._evict_strategy == EvictionStrategy.LFU:
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
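
What the flag buys: copy_(..., non_blocking=True) only overlaps with host work when the CPU end of the transfer is page-locked, which is exactly the pin_memory() / empty_like(..., pin_memory=True) pattern the hunks below use. A minimal standalone sketch (tensor names are illustrative):

    import torch

    src_cpu = torch.randn(4, 128).pin_memory()    # page-locked host memory, a prerequisite for a truly async H2D copy
    dst_gpu = torch.empty_like(src_cpu, device=torch.cuda.current_device())
    dst_gpu.copy_(src_cpu, non_blocking=True)     # enqueued on the current CUDA stream; the call returns immediately
    # kernels enqueued later on the same stream are ordered after the copy and see the completed data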
@@ -312,6 +318,18 @@ class CachedParamMgr(torch.nn.Module):
             cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
+
+        cpu_row_idxs_copy = cpu_row_idxs.cpu()
+
+        # move evict-in rows to the gpu with a non-blocking copy
+        if self._async_copy:
+            if self.buffer_size == 0:
+                rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
+                evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
+                evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
+            else:
+                raise NotImplementedError    # the async path does not support the copy buffer yet
+
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
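
Note that index_select and pin_memory() in this hunk are synchronous host-side work; only the final copy_ overlaps with the eviction bookkeeping that follows. One way to confirm the overlap, sketched with CUDA events (variable names follow the hunk above; the measurement itself is not part of the commit):

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
    end.record()
    # ... CPU-side eviction bookkeeping proceeds here while the transfer is in flight ...
    end.synchronize()
    print(f'H2D transfer: {start.elapsed_time(end):.3f} ms')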
@@ -323,12 +341,24 @@ class CachedParamMgr(torch.nn.Module):
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
+                    # move evict-out rows to pinned cpu memory with a non-blocking copy
+                    if self._async_copy:
+                        evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs)
+                        evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
+                        evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

                 elif self._evict_strategy == EvictionStrategy.LFU:
                     backup_freqs = self.freq_cnter[invalid_idxs].clone()
                     self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+                    if self._async_copy:
+                        evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs)
+                        evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
+                        evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
                     self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)

                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
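
A caveat on these two evict-out blocks: a device-to-host copy_(..., non_blocking=True) returns before the data has landed in evict_out_rows_cpu, so the host must not read that tensor until the stream has drained. The safe pattern, as a sketch:

    evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
    evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)   # returns before the copy completes
    torch.cuda.current_stream().synchronize()                         # required before the CPU reads the rows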
@@ -341,8 +371,13 @@ class CachedParamMgr(torch.nn.Module):
                                                            tgt=self.weight.view(self.num_embeddings, -1))
                 else:
                     # allocate tmp memory on CPU and copy rows on CUDA to CPU.
-                    rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
-                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
+                    # TODO async gpu -> cpu
+                    if self._async_copy:
+                        pass    # evict_out_rows_cpu was already filled by the non-blocking copy above
+                    else:
+                        evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs).cpu()
+                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)

                 self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
                 self.inverted_cached_idx.index_fill_(0, evict_info, -1)
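
In the async branch the pass leaves evict_out_rows_cpu bound to the tensor filled by the earlier non-blocking copy, and the index_copy_ that follows reads it on the CPU. One way the TODO might be resolved (a sketch, not the author's implementation) is a stream synchronization just before that read:

    if self._async_copy:
        torch.cuda.current_stream().synchronize()    # wait for the pending D2H copy to land
    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)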
@@ -359,13 +394,20 @@ class CachedParamMgr(torch.nn.Module):
         # Here also allocate extra memory on CUDA. #cpu_row_idxs
         if self.buffer_size > 0:
             self.limit_buff_index_copyer.index_copy(0,
-                                                    src_index=cpu_row_idxs.cpu(),
+                                                    src_index=cpu_row_idxs_copy,
                                                     tgt_index=slots,
                                                     src=self.weight.view(self.num_embeddings, -1),
                                                     tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
         else:
-            rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
-            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
+            # TODO async copy cpu -> gpu
+            if self._async_copy:
+                pass    # evict_in_rows_gpu was already filled by the non-blocking copy above
+            else:
+                evict_in_rows_gpu = self.weight.view(self.num_embeddings,
+                                                     -1).index_select(0, cpu_row_idxs_copy).cuda()
+            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
         slot_offsets = slots
         self.cached_idx_map[slots] = cpu_row_idxs
         self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
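
The evict-in direction is less delicate: the index_copy_ consuming evict_in_rows_gpu runs on the GPU, and operations enqueued on the same CUDA stream execute in order, so the non-blocking copy needs no explicit synchronization here. A sketch of why this TODO is straightforward:

    evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)    # enqueued on the current stream
    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
    # index_copy_ is stream-ordered after the copy, so it reads fully transferred rows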
@@ -453,4 +495,4 @@ class CachedParamMgr(torch.nn.Module):
             self._cuda_available_row_num -= 1
         self._cpu_to_cuda_numel += self.embedding_dim
         self._cpu_to_cuda_elpase += timer.elapsed
