diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index 127270da0..893188b71 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -6,6 +6,7 @@ from contexttimer import Timer from .copyer import LimitBuffIndexCopyer from enum import Enum import sys +from contextlib import contextmanager class EvictionStrategy(Enum): @@ -55,7 +56,6 @@ class CachedParamMgr(torch.nn.Module): self.num_hits_history = [] self.num_miss_history = [] self.num_write_back_history = [] - self._reset_comm_stats() self._evict_strategy = evict_strategy @@ -71,6 +71,25 @@ class CachedParamMgr(torch.nn.Module): torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(sys.maxsize), persistent=False) + self._elapsed_dict = {} + self._reset_comm_stats() + + def _reset_comm_stats(self): + for k in self._elapsed_dict.keys(): + self._elapsed_dict[k] = 0 + + self._cpu_to_cuda_numel = 0 + self._cuda_to_cpu_numel = 0 + + @contextmanager + def timer(self, name): + with Timer() as t: + yield + torch.cuda.synchronize() + + if name not in self._elapsed_dict.keys(): + self._elapsed_dict[name] = 0 + self._elapsed_dict[name] += t.elapsed def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: """_find_evict_gpu_idxs @@ -193,7 +212,6 @@ class CachedParamMgr(torch.nn.Module): else: preload_cpu_ids = torch.arange(preload_row_num) preload_cuda_row_idxs = preload_cpu_ids.cuda() - if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=preload_cpu_ids, @@ -239,13 +257,20 @@ class CachedParamMgr(torch.nn.Module): def print_comm_stats(self): if self._cuda_to_cpu_numel > 0: + elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"] print( - f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / self._cuda_to_cpu_elapse} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" + f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" ) + print(f'cuda_to_cpu_elapse {elapsed} sec') if self._cpu_to_cuda_numel > 0: + elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"] print( - f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / self._cpu_to_cuda_elpase} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" + f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" ) + print(f'cpu_to_cuda_elpase {elapsed} sec') + + for k, v in self._elapsed_dict.items(): + print(f'{k}: {v}') @torch.no_grad() def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor: @@ -270,44 +295,45 @@ class CachedParamMgr(torch.nn.Module): Returns: torch.Tensor: indices on the cuda_cached_weight. """ - with record_function("(zhg) get unique indices"): - cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) - - assert len(cpu_row_idxs) <= self.cuda_row_num, \ - f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ - f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ - f"Please increase cuda_row_num or decrease the training batch size." - self.evict_backlist = cpu_row_idxs - - with record_function("(zhg) get cpu row idxs"): - comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] + torch.cuda.synchronize() + with self.timer("1_unique_indices") as timer: + with record_function("(cache) get unique indices"): + cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) + + assert len(cpu_row_idxs) <= self.cuda_row_num, \ + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + f"Please increase cuda_row_num or decrease the training batch size." + self.evict_backlist = cpu_row_idxs + torch.cuda.synchronize() + + # O(cache ratio) + with self.timer("2_cpu_row_idx") as timer: + with record_function("(cache) get cpu row idxs"): + comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs)) self.num_write_back_history.append(0) # move sure the cuda rows will not be evicted! - with record_function("(zhg) cache update"): - self._prepare_rows_on_cuda(comm_cpu_row_idxs) + with self.timer("3_prepare_rows_on_cuda") as timer: + with record_function("(cache) prepare_rows_on_cuda"): + self._prepare_rows_on_cuda(comm_cpu_row_idxs) self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) - with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"): - gpu_row_idxs = self._id_to_cached_cuda_id(ids) + with self.timer("4_cpu_to_gpu_row_idxs") as timer: + with record_function("(cache) embed cpu rows idx -> cache gpu row idxs"): + gpu_row_idxs = self._id_to_cached_cuda_id(ids) - # update for LFU. - if self._evict_strategy == EvictionStrategy.LFU: - unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] - self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) + # update for LFU. + if self._evict_strategy == EvictionStrategy.LFU: + unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] + self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) return gpu_row_idxs - def _reset_comm_stats(self): - self._cpu_to_cuda_numel = 0 - self._cpu_to_cuda_elpase = 0 - self._cuda_to_cpu_elapse = 0 - self._cuda_to_cpu_numel = 0 - def _row_in_cuda(self, row_id: int) -> bool: return self.inverted_cached_idx[row_id] != -1 @@ -324,14 +350,17 @@ class CachedParamMgr(torch.nn.Module): # move evict in rows to gpu if self._async_copy: if self.buffer_size == 0: - rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() - evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) - evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) + idxslt_stream = torch.cuda.Stream() + with torch.cuda.stream(idxslt_stream): + rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + # evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) + # evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) else: raise NotImplemented if evict_num > 0: - with Timer() as timer: + torch.cuda.synchronize() + with self.timer("3_1_evict_prepare") as timer: mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) if self._evict_strategy == EvictionStrategy.DATASET: @@ -340,7 +369,9 @@ class CachedParamMgr(torch.nn.Module): # so those idxs will be sorted to end, therefore not being chosen as victim backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() self.cached_idx_map.index_fill_(0, invalid_idxs, -2) - evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + + with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer: + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) # move evict out rows to cpu if self._async_copy: @@ -353,7 +384,10 @@ class CachedParamMgr(torch.nn.Module): elif self._evict_strategy == EvictionStrategy.LFU: backup_freqs = self.freq_cnter[invalid_idxs].clone() self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) - evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + + with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer: + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + if self._async_copy: evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs) @@ -363,6 +397,7 @@ class CachedParamMgr(torch.nn.Module): evict_info = self.cached_idx_map[evict_gpu_row_idxs] + with self.timer("3_2_evict_out_elapse") as timer: if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=evict_gpu_row_idxs, @@ -375,8 +410,12 @@ class CachedParamMgr(torch.nn.Module): if self._async_copy: pass else: - evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs).cpu() + with self.timer("3_2_1_evict_out_index_select") as timer: + evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer: + evict_out_rows_cpu = evict_out_rows_cpu.cpu() + self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) @@ -385,12 +424,15 @@ class CachedParamMgr(torch.nn.Module): self._cuda_available_row_num += evict_num weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim - self._cuda_to_cpu_elapse += timer.elapsed self._cuda_to_cpu_numel += weight_size # print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") - with Timer() as timer: + # slots of cuda weight to evict in + with self.timer("3_3_non_zero") as timer: slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + + # TODO wait for optimize + with self.timer("3_4_evict_in_elapse") as timer: # Here also allocate extra memory on CUDA. #cpu_row_idxs if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, @@ -399,22 +441,34 @@ class CachedParamMgr(torch.nn.Module): src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: - # TODO async copy cpu -> gpu if self._async_copy: + torch.cuda.current_stream().wait_stream(idxslt_stream) + evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) + evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) pass else: - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).cuda() + # TODO hotspot: index select copy cpu -> gpu, cpu index? - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) + with self.timer("3_4_1_evict_in_index_select") as timer: + # narrow index select to a subset of self.weight + # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) + # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) + evict_in_rows_gpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy) + with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer: + evict_in_rows_gpu = evict_in_rows_gpu.cuda() + + with self.timer("3_4_3_evict_in_index_copy") as timer: + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) + + with self.timer("3_4_evict_in_elapse") as timer: slot_offsets = slots self.cached_idx_map[slots] = cpu_row_idxs self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) if self._evict_strategy == EvictionStrategy.LFU: self.freq_cnter.index_fill_(0, slots, 0) self._cuda_available_row_num -= cpu_row_idxs.numel() - self._cpu_to_cuda_elpase += timer.elapsed + weight_size = cpu_row_idxs.numel() * self.embedding_dim self._cpu_to_cuda_numel += weight_size # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") @@ -461,7 +515,6 @@ class CachedParamMgr(torch.nn.Module): self._cuda_available_row_num += 1 self._cuda_to_cpu_numel += self.embedding_dim - self._cuda_to_cpu_elapse += timer.elapsed # self.num_write_back_history[-1] += 1 return max_cpu_row_idx @@ -495,4 +548,3 @@ class CachedParamMgr(torch.nn.Module): self._cuda_available_row_num -= 1 self._cpu_to_cuda_numel += self.embedding_dim - self._cpu_to_cuda_elpase += timer.elapsed