|
|
@ -38,6 +38,7 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
buffer_size: int = 0, |
|
|
|
buffer_size: int = 0, |
|
|
|
pin_weight: bool = False, |
|
|
|
pin_weight: bool = False, |
|
|
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, |
|
|
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, |
|
|
|
|
|
|
|
async_copy: bool = False, |
|
|
|
) -> None: |
|
|
|
) -> None: |
|
|
|
super(CachedParamMgr, self).__init__() |
|
|
|
super(CachedParamMgr, self).__init__() |
|
|
|
self.buffer_size = buffer_size |
|
|
|
self.buffer_size = buffer_size |
|
|
@ -58,6 +59,11 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
self._evict_strategy = evict_strategy |
|
|
|
self._evict_strategy = evict_strategy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._async_copy = async_copy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self._async_copy: |
|
|
|
|
|
|
|
print('use async copy') |
|
|
|
|
|
|
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
# cache_row_idx -> frequency, freq of the cache rows. |
|
|
|
# cache_row_idx -> frequency, freq of the cache rows. |
|
|
|
# classic lfu cache. evict the minimal freq value row in cuda cache. |
|
|
|
# classic lfu cache. evict the minimal freq value row in cuda cache. |
|
|
@ -312,6 +318,18 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA |
|
|
|
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA |
|
|
|
""" |
|
|
|
""" |
|
|
|
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num |
|
|
|
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cpu_row_idxs_copy = cpu_row_idxs.cpu() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# move evict in rows to gpu |
|
|
|
|
|
|
|
if self._async_copy: |
|
|
|
|
|
|
|
if self.buffer_size == 0: |
|
|
|
|
|
|
|
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() |
|
|
|
|
|
|
|
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) |
|
|
|
|
|
|
|
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
raise NotImplemented |
|
|
|
|
|
|
|
|
|
|
|
if evict_num > 0: |
|
|
|
if evict_num > 0: |
|
|
|
with Timer() as timer: |
|
|
|
with Timer() as timer: |
|
|
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) |
|
|
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) |
|
|
@ -323,12 +341,24 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() |
|
|
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() |
|
|
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2) |
|
|
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2) |
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# move evict out rows to cpu |
|
|
|
|
|
|
|
if self._async_copy: |
|
|
|
|
|
|
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, |
|
|
|
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs) |
|
|
|
|
|
|
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) |
|
|
|
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) |
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) |
|
|
|
|
|
|
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
backup_freqs = self.freq_cnter[invalid_idxs].clone() |
|
|
|
backup_freqs = self.freq_cnter[invalid_idxs].clone() |
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) |
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) |
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
|
|
|
|
if self._async_copy: |
|
|
|
|
|
|
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, |
|
|
|
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs) |
|
|
|
|
|
|
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) |
|
|
|
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) |
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) |
|
|
|
|
|
|
|
|
|
|
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs] |
|
|
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs] |
|
|
@ -341,8 +371,13 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
tgt=self.weight.view(self.num_embeddings, -1)) |
|
|
|
tgt=self.weight.view(self.num_embeddings, -1)) |
|
|
|
else: |
|
|
|
else: |
|
|
|
# allocate tmp memory on CPU and copy rows on CUDA to CPU. |
|
|
|
# allocate tmp memory on CPU and copy rows on CUDA to CPU. |
|
|
|
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu() |
|
|
|
# TODO async gpu -> cpu |
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows) |
|
|
|
if self._async_copy: |
|
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, |
|
|
|
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs).cpu() |
|
|
|
|
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) |
|
|
|
|
|
|
|
|
|
|
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) |
|
|
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) |
|
|
|
self.inverted_cached_idx.index_fill_(0, evict_info, -1) |
|
|
|
self.inverted_cached_idx.index_fill_(0, evict_info, -1) |
|
|
@ -359,13 +394,20 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
# Here also allocate extra memory on CUDA. #cpu_row_idxs |
|
|
|
# Here also allocate extra memory on CUDA. #cpu_row_idxs |
|
|
|
if self.buffer_size > 0: |
|
|
|
if self.buffer_size > 0: |
|
|
|
self.limit_buff_index_copyer.index_copy(0, |
|
|
|
self.limit_buff_index_copyer.index_copy(0, |
|
|
|
src_index=cpu_row_idxs.cpu(), |
|
|
|
src_index=cpu_row_idxs_copy, |
|
|
|
tgt_index=slots, |
|
|
|
tgt_index=slots, |
|
|
|
src=self.weight.view(self.num_embeddings, -1), |
|
|
|
src=self.weight.view(self.num_embeddings, -1), |
|
|
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) |
|
|
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) |
|
|
|
else: |
|
|
|
else: |
|
|
|
rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda() |
|
|
|
# TODO async copy cpu -> gpu |
|
|
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows) |
|
|
|
if self._async_copy: |
|
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings, |
|
|
|
|
|
|
|
-1).index_select(0, cpu_row_idxs_copy).cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) |
|
|
|
|
|
|
|
|
|
|
|
slot_offsets = slots |
|
|
|
slot_offsets = slots |
|
|
|
self.cached_idx_map[slots] = cpu_row_idxs |
|
|
|
self.cached_idx_map[slots] = cpu_row_idxs |
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) |
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) |
|
|
@ -453,4 +495,4 @@ class CachedParamMgr(torch.nn.Module): |
|
|
|
self._cuda_available_row_num -= 1 |
|
|
|
self._cuda_available_row_num -= 1 |
|
|
|
|
|
|
|
|
|
|
|
self._cpu_to_cuda_numel += self.embedding_dim |
|
|
|
self._cpu_to_cuda_numel += self.embedding_dim |
|
|
|
self._cpu_to_cuda_elpase += timer.elapsed |
|
|
|
self._cpu_to_cuda_elpase += timer.elapsed |
|
|
|