|
|
|
@ -256,13 +256,13 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
assert torch.all(self.cached_idx_map == -1).item()
|
|
|
|
|
|
|
|
|
|
def print_comm_stats(self):
|
|
|
|
|
if self._cuda_to_cpu_numel > 0:
|
|
|
|
|
if self._cuda_to_cpu_numel > 0 and "3_2_2_evict_out_gpu_to_cpu_copy" in self._elapsed_dict:
|
|
|
|
|
elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"]
|
|
|
|
|
print(
|
|
|
|
|
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
|
|
|
|
|
)
|
|
|
|
|
print(f'cuda_to_cpu_elapse {elapsed} sec')
|
|
|
|
|
if self._cpu_to_cuda_numel > 0:
|
|
|
|
|
if self._cpu_to_cuda_numel > 0 and "3_4_2_evict_in_gpu_to_cpu_copy" in self._elapsed_dict:
|
|
|
|
|
elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"]
|
|
|
|
|
print(
|
|
|
|
|
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
|
|
|
|
@ -382,6 +382,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
|
|
|
|
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
with self.timer("3_1_0_backup_freqs") as timer:
|
|
|
|
|
backup_freqs = self.freq_cnter[invalid_idxs].clone()
|
|
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
|
|
|
|
|
|
|
|
|
@ -393,6 +394,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs)
|
|
|
|
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
|
|
|
|
with self.timer("3_1_2_find_evict_index_copy") as timer:
|
|
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
|
|
|
|
|
|
|
|
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
|
|
|
@ -416,6 +418,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer:
|
|
|
|
|
evict_out_rows_cpu = evict_out_rows_cpu.cpu()
|
|
|
|
|
|
|
|
|
|
with self.timer("3_2_2_evict_out_index_select") as timer:
|
|
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
|
|
|
|
|
|
|
|
|
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
|
|
|
@ -447,13 +450,12 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
# TODO hotspot: index select copy cpu -> gpu, cpu index?
|
|
|
|
|
|
|
|
|
|
with self.timer("3_4_1_evict_in_index_select") as timer:
|
|
|
|
|
# narrow index select to a subset of self.weight
|
|
|
|
|
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
|
|
|
|
|
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
|
|
|
|
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy)
|
|
|
|
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
|
|
|
|
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
|
|
|
|
|
|
|
|
|
with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer:
|
|
|
|
|
evict_in_rows_gpu = evict_in_rows_gpu.cuda()
|
|
|
|
@ -461,10 +463,9 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
with self.timer("3_4_3_evict_in_index_copy") as timer:
|
|
|
|
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
|
|
|
|
|
|
|
|
|
|
with self.timer("3_4_evict_in_elapse") as timer:
|
|
|
|
|
slot_offsets = slots
|
|
|
|
|
with self.timer("3_5_evict_in_elapse_final") as timer:
|
|
|
|
|
self.cached_idx_map[slots] = cpu_row_idxs
|
|
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
|
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots)
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
self.freq_cnter.index_fill_(0, slots, 0)
|
|
|
|
|
self._cuda_available_row_num -= cpu_row_idxs.numel()
|
|
|
|
|