|
|
|
@ -256,13 +256,13 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
assert torch.all(self.cached_idx_map == -1).item() |
|
|
|
|
|
|
|
|
|
def print_comm_stats(self): |
|
|
|
|
if self._cuda_to_cpu_numel > 0: |
|
|
|
|
if self._cuda_to_cpu_numel > 0 and "3_2_2_evict_out_gpu_to_cpu_copy" in self._elapsed_dict: |
|
|
|
|
elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"] |
|
|
|
|
print( |
|
|
|
|
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" |
|
|
|
|
) |
|
|
|
|
print(f'cuda_to_cpu_elapse {elapsed} sec') |
|
|
|
|
if self._cpu_to_cuda_numel > 0: |
|
|
|
|
if self._cpu_to_cuda_numel > 0 and "3_4_2_evict_in_gpu_to_cpu_copy" in self._elapsed_dict: |
|
|
|
|
elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"] |
|
|
|
|
print( |
|
|
|
|
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" |
|
|
|
@ -382,8 +382,9 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) |
|
|
|
|
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
|
backup_freqs = self.freq_cnter[invalid_idxs].clone() |
|
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) |
|
|
|
|
with self.timer("3_1_0_backup_freqs") as timer: |
|
|
|
|
backup_freqs = self.freq_cnter[invalid_idxs].clone() |
|
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) |
|
|
|
|
|
|
|
|
|
with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer: |
|
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
@ -393,7 +394,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs) |
|
|
|
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) |
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) |
|
|
|
|
with self.timer("3_1_2_find_evict_index_copy") as timer: |
|
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) |
|
|
|
|
|
|
|
|
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs] |
|
|
|
|
|
|
|
|
@ -416,7 +418,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer: |
|
|
|
|
evict_out_rows_cpu = evict_out_rows_cpu.cpu() |
|
|
|
|
|
|
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) |
|
|
|
|
with self.timer("3_2_2_evict_out_index_select") as timer: |
|
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) |
|
|
|
|
|
|
|
|
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) |
|
|
|
|
self.inverted_cached_idx.index_fill_(0, evict_info, -1) |
|
|
|
@ -447,13 +450,12 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) |
|
|
|
|
pass |
|
|
|
|
else: |
|
|
|
|
# TODO hotspot: index select copy cpu -> gpu, cpu index? |
|
|
|
|
|
|
|
|
|
with self.timer("3_4_1_evict_in_index_select") as timer: |
|
|
|
|
# narrow index select to a subset of self.weight |
|
|
|
|
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) |
|
|
|
|
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) |
|
|
|
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy) |
|
|
|
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings, |
|
|
|
|
-1).index_select(0, cpu_row_idxs_copy).pin_memory() |
|
|
|
|
|
|
|
|
|
with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer: |
|
|
|
|
evict_in_rows_gpu = evict_in_rows_gpu.cuda() |
|
|
|
@ -461,10 +463,9 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
with self.timer("3_4_3_evict_in_index_copy") as timer: |
|
|
|
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) |
|
|
|
|
|
|
|
|
|
with self.timer("3_4_evict_in_elapse") as timer: |
|
|
|
|
slot_offsets = slots |
|
|
|
|
with self.timer("3_5_evict_in_elapse_final") as timer: |
|
|
|
|
self.cached_idx_map[slots] = cpu_row_idxs |
|
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) |
|
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots) |
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
|
self.freq_cnter.index_fill_(0, slots, 0) |
|
|
|
|
self._cuda_available_row_num -= cpu_row_idxs.numel() |
|
|
|
|