mirror of https://github.com/hpcaitech/ColossalAI
[embedding] non-blocking cpu-gpu copy (#1647)
parent
0767f67a0f
commit
04443605a5
|
@ -38,6 +38,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
buffer_size: int = 0,
|
buffer_size: int = 0,
|
||||||
pin_weight: bool = False,
|
pin_weight: bool = False,
|
||||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||||
|
async_copy: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(CachedParamMgr, self).__init__()
|
super(CachedParamMgr, self).__init__()
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
|
@ -58,6 +59,11 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
|
|
||||||
self._evict_strategy = evict_strategy
|
self._evict_strategy = evict_strategy
|
||||||
|
|
||||||
|
self._async_copy = async_copy
|
||||||
|
|
||||||
|
if self._async_copy:
|
||||||
|
print('use async copy')
|
||||||
|
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
# cache_row_idx -> frequency, freq of the cache rows.
|
# cache_row_idx -> frequency, freq of the cache rows.
|
||||||
# classic lfu cache. evict the minimal freq value row in cuda cache.
|
# classic lfu cache. evict the minimal freq value row in cuda cache.
|
||||||
|
@ -312,6 +318,18 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
|
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
|
||||||
"""
|
"""
|
||||||
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
|
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
|
||||||
|
|
||||||
|
cpu_row_idxs_copy = cpu_row_idxs.cpu()
|
||||||
|
|
||||||
|
# move evict in rows to gpu
|
||||||
|
if self._async_copy:
|
||||||
|
if self.buffer_size == 0:
|
||||||
|
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||||
|
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
||||||
|
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
||||||
|
else:
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
if evict_num > 0:
|
if evict_num > 0:
|
||||||
with Timer() as timer:
|
with Timer() as timer:
|
||||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||||
|
@ -323,12 +341,24 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||||
|
|
||||||
|
# move evict out rows to cpu
|
||||||
|
if self._async_copy:
|
||||||
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
|
-1).index_select(0, evict_gpu_row_idxs)
|
||||||
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||||
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||||
|
|
||||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||||
backup_freqs = self.freq_cnter[invalid_idxs].clone()
|
backup_freqs = self.freq_cnter[invalid_idxs].clone()
|
||||||
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
|
||||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||||
|
if self._async_copy:
|
||||||
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
|
-1).index_select(0, evict_gpu_row_idxs)
|
||||||
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||||
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||||
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
||||||
|
|
||||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||||
|
@ -341,8 +371,13 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
tgt=self.weight.view(self.num_embeddings, -1))
|
tgt=self.weight.view(self.num_embeddings, -1))
|
||||||
else:
|
else:
|
||||||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||||
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
|
# TODO async gpu -> cpu
|
||||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
|
if self._async_copy:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
|
-1).index_select(0, evict_gpu_row_idxs).cpu()
|
||||||
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
|
||||||
|
|
||||||
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
||||||
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
|
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
|
||||||
|
@ -359,13 +394,20 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
# Here also allocate extra memory on CUDA. #cpu_row_idxs
|
# Here also allocate extra memory on CUDA. #cpu_row_idxs
|
||||||
if self.buffer_size > 0:
|
if self.buffer_size > 0:
|
||||||
self.limit_buff_index_copyer.index_copy(0,
|
self.limit_buff_index_copyer.index_copy(0,
|
||||||
src_index=cpu_row_idxs.cpu(),
|
src_index=cpu_row_idxs_copy,
|
||||||
tgt_index=slots,
|
tgt_index=slots,
|
||||||
src=self.weight.view(self.num_embeddings, -1),
|
src=self.weight.view(self.num_embeddings, -1),
|
||||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||||
else:
|
else:
|
||||||
rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
|
# TODO async copy cpu -> gpu
|
||||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
|
if self._async_copy:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
||||||
|
-1).index_select(0, cpu_row_idxs_copy).cuda()
|
||||||
|
|
||||||
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
|
||||||
|
|
||||||
slot_offsets = slots
|
slot_offsets = slots
|
||||||
self.cached_idx_map[slots] = cpu_row_idxs
|
self.cached_idx_map[slots] = cpu_row_idxs
|
||||||
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
||||||
|
|
Loading…
Reference in New Issue