mirror of https://github.com/hpcaitech/ColossalAI
[embedding] non-blocking cpu-gpu copy (#1647)
parent 0767f67a0f
commit 04443605a5
@@ -38,6 +38,7 @@ class CachedParamMgr(torch.nn.Module):
         buffer_size: int = 0,
         pin_weight: bool = False,
         evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
+        async_copy: bool = False,
     ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
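The only signature change is the new `async_copy` switch. A minimal construction sketch, assuming the leading `weight` tensor and `cuda_row_num` argument that are not visible in this hunk, and an import path that may differ between ColossalAI versions:

import torch
# assumed import location; adjust to your ColossalAI version
from colossalai.nn.parallel.layers.cache_embedding import CachedParamMgr, EvictionStrategy

weight = torch.randn(100_000, 64)        # full embedding table kept on CPU
mgr = CachedParamMgr(
    weight,                              # assumed first positional argument
    cuda_row_num=4_096,                  # assumed: number of rows cached on the GPU
    buffer_size=0,                       # the async path below only supports buffer_size == 0
    pin_weight=True,
    evict_strategy=EvictionStrategy.LFU,
    async_copy=True,                     # opt in to the non-blocking CPU<->GPU copies
)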
@@ -58,6 +59,11 @@ class CachedParamMgr(torch.nn.Module):
 
         self._evict_strategy = evict_strategy
 
+        self._async_copy = async_copy
+
+        if self._async_copy:
+            print('use async copy')
+
         if self._evict_strategy == EvictionStrategy.LFU:
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
@@ -312,6 +318,18 @@ class CachedParamMgr(torch.nn.Module):
             cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
+
+        cpu_row_idxs_copy = cpu_row_idxs.cpu()
+
+        # move evict in rows to gpu
+        if self._async_copy:
+            if self.buffer_size == 0:
+                rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
+                evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
+                evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
+            else:
+                raise NotImplementedError
+
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
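The new block stages the incoming rows through pinned host memory so the host-to-device copy can be launched with `non_blocking=True` and overlap with the eviction work that follows. A standalone sketch of the same pattern, with illustrative names only:

import torch

def rows_to_gpu_async(weight_cpu: torch.Tensor, row_idxs_cpu: torch.Tensor) -> torch.Tensor:
    """Gather rows on the CPU, pin them, and start a non-blocking H2D copy.

    The returned GPU tensor is safe to consume from kernels enqueued later on the
    same stream; otherwise synchronize, e.g. torch.cuda.current_stream().synchronize().
    """
    rows_cpu = weight_cpu.index_select(0, row_idxs_cpu).pin_memory()
    rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
    rows_gpu.copy_(rows_cpu, non_blocking=True)
    return rows_gpu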
@@ -323,12 +341,24 @@ class CachedParamMgr(torch.nn.Module):
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
+                    # move evict out rows to cpu
+                    if self._async_copy:
+                        evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs)
+                        evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
+                        evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
 
                 elif self._evict_strategy == EvictionStrategy.LFU:
                     backup_freqs = self.freq_cnter[invalid_idxs].clone()
                     self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+                    if self._async_copy:
+                        evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs)
+                        evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
+                        evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
                     self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
 
                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
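Both eviction strategies now pre-stage the evict-out rows into a pinned CPU buffer with a non-blocking device-to-host copy; a D2H copy is only asynchronous when the destination is pinned, and the host must not read the buffer until the copy has finished. A hedged standalone sketch of that direction (names are illustrative, not the manager's API):

import torch

def rows_to_cpu_async(cuda_weight: torch.Tensor, row_idxs_gpu: torch.Tensor) -> torch.Tensor:
    """Start a non-blocking D2H copy of selected GPU rows into pinned host memory.

    The caller must synchronize the copying stream (or wait on a recorded event)
    before reading the returned CPU tensor.
    """
    rows_gpu = cuda_weight.index_select(0, row_idxs_gpu)
    rows_cpu = torch.empty_like(rows_gpu, device='cpu', pin_memory=True)
    rows_cpu.copy_(rows_gpu, non_blocking=True)
    return rows_cpu

# usage sketch: overlap the copy with other GPU work, then synchronize before reading
# staged = rows_to_cpu_async(gpu_cache.view(cache_rows, -1), evict_idxs)
# ...                                        # unrelated kernels / index bookkeeping
# torch.cuda.current_stream().synchronize()
# cpu_table.index_copy_(0, evicted_row_ids, staged)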
@@ -341,8 +371,13 @@ class CachedParamMgr(torch.nn.Module):
                                                             tgt=self.weight.view(self.num_embeddings, -1))
                 else:
                     # allocate tmp memory on CPU and copy rows on CUDA to CPU.
-                    rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
-                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
+                    # TODO async gpu -> cpu
+                    if self._async_copy:
+                        pass
+                    else:
+                        evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                          -1).index_select(0, evict_gpu_row_idxs).cpu()
+                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
 
                 self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
                 self.inverted_cached_idx.index_fill_(0, evict_info, -1)
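In the `_async_copy` branch the staged `evict_out_rows_cpu` from the previous hunk is reused (hence the `pass`), while the fallback still does a blocking `.cpu()` gather; either way the rows are scattered back into the CPU-side `weight` with `index_copy_`. A tiny self-contained illustration of that gather/scatter pair, independent of the cache manager:

import torch

# toy "weight" table with 6 rows of width 4
weight = torch.arange(24, dtype=torch.float32).view(6, 4)

# gather rows 1 and 4 (e.g. rows currently cached on the GPU)
idx = torch.tensor([1, 4])
rows = weight.index_select(0, idx)        # shape (2, 4), a copy

# ... rows get updated while they live in the cache ...
rows += 100.0

# scatter the updated rows back to their original positions (the eviction write-back)
weight.index_copy_(0, idx, rows)
print(weight[1], weight[4])               # both rows reflect the +100 update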
@@ -359,13 +394,20 @@ class CachedParamMgr(torch.nn.Module):
         # Here also allocate extra memory on CUDA. #cpu_row_idxs
         if self.buffer_size > 0:
             self.limit_buff_index_copyer.index_copy(0,
-                                                    src_index=cpu_row_idxs.cpu(),
+                                                    src_index=cpu_row_idxs_copy,
                                                     tgt_index=slots,
                                                     src=self.weight.view(self.num_embeddings, -1),
                                                     tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
         else:
-            rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
-            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
+            # TODO async copy cpu -> gpu
+            if self._async_copy:
+                pass
+            else:
+                evict_in_rows_gpu = self.weight.view(self.num_embeddings,
+                                                     -1).index_select(0, cpu_row_idxs_copy).cuda()
+
+            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
+
         slot_offsets = slots
         self.cached_idx_map[slots] = cpu_row_idxs
         self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
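For the cpu -> gpu fill of the new cache slots, the `_async_copy` branch reuses the `evict_in_rows_gpu` staged at the top of the function, while the fallback issues a fresh blocking `.cuda()` copy. A rough standalone way to see what pinned, non-blocking host-to-device copies buy over the default pageable path (a sketch, not part of this commit):

import time

import torch

def time_h2d(rows: torch.Tensor, pinned: bool, non_blocking: bool, iters: int = 20) -> float:
    """Average seconds per host-to-device copy of `rows` under the given settings."""
    src = rows.pin_memory() if pinned else rows.clone()
    dst = torch.empty_like(src, device='cuda')
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        dst.copy_(src, non_blocking=non_blocking)
    torch.cuda.synchronize()              # wait for every queued copy before stopping the clock
    return (time.perf_counter() - start) / iters

if torch.cuda.is_available():
    batch = torch.randn(4096, 128)        # roughly one "evict-in" batch of embedding rows
    print('pageable, blocking    :', time_h2d(batch, pinned=False, non_blocking=False))
    print('pinned,   non-blocking:', time_h2d(batch, pinned=True, non_blocking=True))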
@@ -453,4 +495,4 @@ class CachedParamMgr(torch.nn.Module):
         self._cuda_available_row_num -= 1
 
         self._cpu_to_cuda_numel += self.embedding_dim
-        self._cpu_to_cuda_elpase += timer.elapsed
+        self._cpu_to_cuda_elpase += timer.elapsed