diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index 893188b71..d89290145 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -256,13 +256,13 @@ class CachedParamMgr(torch.nn.Module): assert torch.all(self.cached_idx_map == -1).item() def print_comm_stats(self): - if self._cuda_to_cpu_numel > 0: + if self._cuda_to_cpu_numel > 0 and "3_2_2_evict_out_gpu_to_cpu_copy" in self._elapsed_dict: elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"] print( f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" ) print(f'cuda_to_cpu_elapse {elapsed} sec') - if self._cpu_to_cuda_numel > 0: + if self._cpu_to_cuda_numel > 0 and "3_4_2_evict_in_gpu_to_cpu_copy" in self._elapsed_dict: elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"] print( f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" @@ -382,8 +382,9 @@ class CachedParamMgr(torch.nn.Module): self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) elif self._evict_strategy == EvictionStrategy.LFU: - backup_freqs = self.freq_cnter[invalid_idxs].clone() - self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) + with self.timer("3_1_0_backup_freqs") as timer: + backup_freqs = self.freq_cnter[invalid_idxs].clone() + self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer: evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) @@ -393,7 +394,8 @@ class CachedParamMgr(torch.nn.Module): -1).index_select(0, evict_gpu_row_idxs) evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) - self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) + with self.timer("3_1_2_find_evict_index_copy") as timer: + self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) evict_info = self.cached_idx_map[evict_gpu_row_idxs] @@ -416,7 +418,8 @@ class CachedParamMgr(torch.nn.Module): with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer: evict_out_rows_cpu = evict_out_rows_cpu.cpu() - self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) + with self.timer("3_2_2_evict_out_index_select") as timer: + self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) self.inverted_cached_idx.index_fill_(0, evict_info, -1) @@ -447,13 +450,12 @@ class CachedParamMgr(torch.nn.Module): evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) pass else: - # TODO hotspot: index select copy cpu -> gpu, cpu index? - with self.timer("3_4_1_evict_in_index_select") as timer: # narrow index select to a subset of self.weight # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) - evict_in_rows_gpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy) + evict_in_rows_gpu = self.weight.view(self.num_embeddings, + -1).index_select(0, cpu_row_idxs_copy).pin_memory() with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer: evict_in_rows_gpu = evict_in_rows_gpu.cuda() @@ -461,10 +463,9 @@ class CachedParamMgr(torch.nn.Module): with self.timer("3_4_3_evict_in_index_copy") as timer: self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) - with self.timer("3_4_evict_in_elapse") as timer: - slot_offsets = slots + with self.timer("3_5_evict_in_elapse_final") as timer: self.cached_idx_map[slots] = cpu_row_idxs - self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) + self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots) if self._evict_strategy == EvictionStrategy.LFU: self.freq_cnter.index_fill_(0, slots, 0) self._cuda_available_row_num -= cpu_row_idxs.numel()