
[embedding] add more detail profiling (#1656)

Branch: pull/1657/head
Author: Jiarui Fang (committed by GitHub)
Commit: 988570e4a6
1 changed file:
  colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py (25 changed lines)
@@ -256,13 +256,13 @@ class CachedParamMgr(torch.nn.Module):
         assert torch.all(self.cached_idx_map == -1).item()
 
     def print_comm_stats(self):
-        if self._cuda_to_cpu_numel > 0:
+        if self._cuda_to_cpu_numel > 0 and "3_2_2_evict_out_gpu_to_cpu_copy" in self._elapsed_dict:
             elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"]
             print(
                 f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
             )
             print(f'cuda_to_cpu_elapse {elapsed} sec')
-        if self._cpu_to_cuda_numel > 0:
+        if self._cpu_to_cuda_numel > 0 and "3_4_2_evict_in_gpu_to_cpu_copy" in self._elapsed_dict:
             elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"]
             print(
                 f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
@@ -382,8 +382,9 @@ class CachedParamMgr(torch.nn.Module):
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
                 elif self._evict_strategy == EvictionStrategy.LFU:
-                    backup_freqs = self.freq_cnter[invalid_idxs].clone()
-                    self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
+                    with self.timer("3_1_0_backup_freqs") as timer:
+                        backup_freqs = self.freq_cnter[invalid_idxs].clone()
+                        self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
                 with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer:
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
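
Note: the LFU branch temporarily pushes the counters of in-batch slots to sys.maxsize so the eviction search cannot select them, then restores them; the new 3_1_0_backup_freqs timer isolates the cost of that protection step. A standalone toy version of the protect/select/restore pattern, with _find_evict_gpu_idxs approximated by a smallest-k topk (an assumption about its behaviour):

import sys
import torch

# per-GPU-slot access counters; lower count = stronger LFU eviction candidate
freq_cnter = torch.tensor([5, 1, 9, 2, 7], dtype=torch.long)
invalid_idxs = torch.tensor([1, 3])            # slots holding rows needed by the current batch

backup_freqs = freq_cnter[invalid_idxs].clone()
freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)    # protect them from the eviction search

evict_num = 2
evict_gpu_row_idxs = torch.topk(freq_cnter, evict_num, largest=False).indices
print(evict_gpu_row_idxs)                      # tensor([0, 4]): the protected slots 1 and 3 survive

freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)   # restore the true counters
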
@@ -393,7 +394,8 @@ class CachedParamMgr(torch.nn.Module):
                                                                       -1).index_select(0, evict_gpu_row_idxs)
                     evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
                     evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
-                self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
+                with self.timer("3_1_2_find_evict_index_copy") as timer:
+                    self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
 
                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
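
Note: the surrounding async path stages evicted rows in a pinned CPU buffer so the device-to-host copy can be issued with non_blocking=True. A self-contained sketch of that staging pattern (tensor names and sizes are illustrative; it degrades to a plain CPU copy when no GPU is present):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
cuda_cached_weight = torch.randn(1024, 64, device=device)        # stands in for the GPU-side cache
evict_gpu_row_idxs = torch.tensor([3, 17, 256], device=device)

# gather the rows being evicted, then copy them into a pinned host buffer asynchronously
evict_out_rows_gpu = cuda_cached_weight.index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu",
                                      pin_memory=torch.cuda.is_available())
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)

# the device-to-host copy must finish before the host buffer is consumed
if torch.cuda.is_available():
    torch.cuda.synchronize()
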
@@ -416,7 +418,8 @@ class CachedParamMgr(torch.nn.Module):
                 with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer:
                     evict_out_rows_cpu = evict_out_rows_cpu.cpu()
-                self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
+                with self.timer("3_2_2_evict_out_index_select") as timer:
+                    self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
 
             self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
             self.inverted_cached_idx.index_fill_(0, evict_info, -1)
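
Note: once the rows are on the host, they are written back into the full CPU weight with index_copy_ and both slot maps are reset to -1 to mark the slots free. A toy-sized illustration of that bookkeeping (shapes and values invented for the example):

import torch

num_embeddings, dim = 8, 4
weight_cpu = torch.zeros(num_embeddings, dim)                               # full embedding table on CPU
cached_idx_map = torch.tensor([6, 2, -1, 5])                                # GPU slot -> row id, -1 = empty
inverted_cached_idx = torch.full((num_embeddings,), -1, dtype=torch.long)   # row id -> GPU slot
inverted_cached_idx[torch.tensor([6, 2, 5])] = torch.tensor([0, 1, 3])

evict_gpu_row_idxs = torch.tensor([0, 3])                 # evict slots 0 and 3
evict_info = cached_idx_map[evict_gpu_row_idxs]           # row ids held in those slots: 6 and 5
evict_out_rows_cpu = torch.ones(2, dim)                   # stands in for the rows copied off the GPU

weight_cpu.index_copy_(0, evict_info, evict_out_rows_cpu)   # write the rows back into the CPU table
cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)       # mark the GPU slots free
inverted_cached_idx.index_fill_(0, evict_info, -1)          # drop the row -> slot mapping
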
@@ -447,13 +450,12 @@ class CachedParamMgr(torch.nn.Module):
                     evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
                 pass
             else:
-                # TODO hotspot: index select copy cpu -> gpu, cpu index?
                 with self.timer("3_4_1_evict_in_index_select") as timer:
                     # narrow index select to a subset of self.weight
                     # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
                     # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
-                    evict_in_rows_gpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy)
+                    evict_in_rows_gpu = self.weight.view(self.num_embeddings,
+                                                         -1).index_select(0, cpu_row_idxs_copy).pin_memory()
                 with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer:
                     evict_in_rows_gpu = evict_in_rows_gpu.cuda()
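
Note: the evict-in hotspot is now split into two timed stages, and the gathered rows are pinned before the transfer so the host-to-device copy goes through page-locked memory. A minimal sketch of the same two-stage pattern (names and sizes are illustrative, and it assumes a CUDA device is available):

import torch

if torch.cuda.is_available():
    num_embeddings, dim = 100_000, 64
    weight_cpu = torch.randn(num_embeddings, dim)                 # CPU-resident embedding weight
    cpu_row_idxs_copy = torch.randint(0, num_embeddings, (4096,))

    # stage 1 (cf. "3_4_1_evict_in_index_select"): gather rows into a page-locked buffer
    evict_in_rows = weight_cpu.index_select(0, cpu_row_idxs_copy).pin_memory()

    # stage 2 (cf. "3_4_2_evict_in_gpu_to_cpu_copy"): the H2D copy benefits from pinned memory
    evict_in_rows_gpu = evict_in_rows.cuda()
    torch.cuda.synchronize()
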
@@ -461,10 +463,9 @@ class CachedParamMgr(torch.nn.Module):
         with self.timer("3_4_3_evict_in_index_copy") as timer:
             self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
-        with self.timer("3_4_evict_in_elapse") as timer:
-            slot_offsets = slots
+        with self.timer("3_5_evict_in_elapse_final") as timer:
             self.cached_idx_map[slots] = cpu_row_idxs
-            self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
+            self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots)
             if self._evict_strategy == EvictionStrategy.LFU:
                 self.freq_cnter.index_fill_(0, slots, 0)
             self._cuda_available_row_num -= cpu_row_idxs.numel()
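
Note: print_comm_stats turns the timings accumulated by these stages into bandwidth as numel * element size / 1e6 / elapsed. A small worked example with assumed numbers:

cuda_to_cpu_numel = 50_000_000     # elements moved GPU -> CPU over the run (assumed)
elem_size_in_byte = 4              # fp32
elapsed = 0.8                      # seconds under "3_2_2_evict_out_gpu_to_cpu_copy" (assumed)

bandwidth_mb_per_s = cuda_to_cpu_numel * elem_size_in_byte / 1e6 / elapsed
print(f"CUDA->CPU BWD {bandwidth_mb_per_s} MB/s {cuda_to_cpu_numel / 1e6} M elem")
# prints: CUDA->CPU BWD 250.0 MB/s 50.0 M elem
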
