From 38c68b5b9adff9e30b0ac487b79d1b0b9f0c6fd0 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 22 Sep 2022 11:16:25 +0800 Subject: [PATCH] [embedding] rollback for better FAW performance (#1625) --- .../layers/cache_embedding/cache_mgr.py | 140 +++++++----------- tests/test_layers/test_cache_embedding.py | 2 +- 2 files changed, 52 insertions(+), 90 deletions(-) diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index ef20cfc79..6f591ad44 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -20,15 +20,15 @@ class CachedParamMgr(torch.nn.Module): CPU maintains the entire original weight. CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`. During training, GPU needs to transmit embedding rows between CPU and GPU. - Args: weight (torch.Tensor): the weight of the Embedding layer. cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0. buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000. - pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False. - evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options. `EvictionStrategy.LFU` uses the least frequently used cache. `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume. - Default as EvictionStrategy.DATASET. - use_cpu_caching (bool, optional): use cpu to execute cache indexing. It is slower than use gpu. + pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False. + evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options. + `EvictionStrategy.LFU`: use the least frequently used cache. + `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume. + Defaults to EvictionStrategy.DATASET. """ def __init__( @@ -38,7 +38,6 @@ class CachedParamMgr(torch.nn.Module): buffer_size: int = 0, pin_weight: bool = False, evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, - use_cpu_caching=False, ) -> None: super(CachedParamMgr, self).__init__() self.buffer_size = buffer_size @@ -48,13 +47,6 @@ class CachedParamMgr(torch.nn.Module): self.pin_weight = pin_weight self.elem_size_in_byte = weight.element_size() - self._cpu_caching = use_cpu_caching - - if self._cpu_caching: - self._cache_dev = torch.device('cpu') - else: - self._cache_dev = torch.cuda.current_device() - # weight configure self._init_weight(weight) @@ -69,24 +61,16 @@ class CachedParamMgr(torch.nn.Module): if self._evict_strategy == EvictionStrategy.LFU: # cache_row_idx -> frequency, freq of the cache rows. # classic lfu cache. evict the minimal freq value row in cuda cache. - if self._cpu_caching: - self.freq_cnter = torch.empty(self.cuda_row_num, device=self._cache_dev, - dtype=torch.long).fill_(sys.maxsize) - - else: - self.register_buffer("freq_cnter", - torch.empty(self.cuda_row_num, device=self._cache_dev, - dtype=torch.long).fill_(sys.maxsize), - persistent=False) + self.register_buffer("freq_cnter", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(sys.maxsize), + persistent=False) def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: """_find_evict_gpu_idxs - Find the gpu idxs to be evicted, according to their freq. - Args: evict_num (int): how many rows has to be evicted - Returns: torch.Tensor: a list tensor (1D), contains the gpu_row_idxs. """ @@ -117,32 +101,26 @@ class CachedParamMgr(torch.nn.Module): self.weight = weight.pin_memory() if self.pin_weight else weight # map original id to new id with respect to frequency # id -> cpu_row_idx + self.register_buffer( + "idx_map", + torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), + persistent=False, + ) - if self._cpu_caching: - self.idx_map = torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev) - self.cached_idx_map = torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1) - self.inverted_cached_idx = torch.zeros(self.num_embeddings, device=self._cache_dev, - dtype=torch.long).fill_(-1) - else: - self.register_buffer( - "idx_map", - torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev), - persistent=False, - ) - - # cached_idx_map: gpu_row_idx -> cpu_row_idx - self.register_buffer("cached_idx_map", - torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1), - persistent=False) - - # cpu_row_id -> gpu_row_idx. - # gpu_row_idx as -1 means cpu_row_id not in CUDA. - self.register_buffer("inverted_cached_idx", - torch.zeros(self.num_embeddings, device=self._cache_dev, - dtype=torch.long).fill_(-1), - persistent=False) - - self.evict_backlist = torch.tensor([], device=self._cache_dev) + # cached_idx_map: gpu_row_idx -> cpu_row_idx + self.register_buffer("cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + # cpu_row_id -> gpu_row_idx. + # gpu_row_idx as -1 means cpu_row_id not in CUDA. + self.register_buffer("inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) # index copy buffer size should less than 10% of cuda weight. if self.buffer_size > 0: @@ -157,10 +135,8 @@ class CachedParamMgr(torch.nn.Module): def cpu_weight_data(self, row_idx: int) -> torch.Tensor: """ access a row of CPU weight. - Args: row_idx (int): the idx of rows - Returns: torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D. """ @@ -178,14 +154,12 @@ class CachedParamMgr(torch.nn.Module): """reorder reorder the weight according to ids' frequency in dataset before training. Execute only once before training, also known as warmup phase. - + Note: If you would like to use the DATASET as the eviction strategy, you must call this function. - Note: If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize The frequency in LFU cache using the dataset statistics. - Args: ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight. warmup_ratio (float): the amount of chunks preloaded in cuda cache @@ -209,24 +183,24 @@ class CachedParamMgr(torch.nn.Module): # extract rows from cpu weight if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None: freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True) - preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev) + preload_cuda_row_idxs = torch.arange(preload_row_num).cuda() else: - preload_cpu_ids = torch.arange(preload_row_num, device=self.weight.device) - preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev) + preload_cpu_ids = torch.arange(preload_row_num) + preload_cuda_row_idxs = preload_cpu_ids.cuda() if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=preload_cpu_ids, - tgt_index=preload_cuda_row_idxs.cuda(), + tgt_index=preload_cuda_row_idxs, src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs.cuda(), + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, preload_rows) # update auxiliary info - self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.to(self._cache_dev) + self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs self._cuda_available_row_num -= preload_row_num @@ -235,7 +209,7 @@ class CachedParamMgr(torch.nn.Module): if ids_freq_mapping is None: self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0) else: - self.freq_cnter[preload_cuda_row_idxs] = freq_value.to(self._cache_dev) + self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() print(f'Cache warmup finished cost {timer.elapsed} sec.') @@ -245,7 +219,7 @@ class CachedParamMgr(torch.nn.Module): """ slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) row_ids = self.cached_idx_map[slots] - rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots.cuda()).cpu() + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows) self.cached_idx_map.index_fill_(0, slots, -1) self.inverted_cached_idx.index_fill_(0, row_ids, -1) @@ -272,10 +246,8 @@ class CachedParamMgr(torch.nn.Module): """ convert ids to indices in self.cuda_cached_weight. Implemented with parallel operations on GPU. - Args: ids (torch.Tensor): ids from the dataset - Returns: torch.Tensor: contains indices in self.cuda_cached_weight """ @@ -287,14 +259,12 @@ class CachedParamMgr(torch.nn.Module): def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor: """ move the cpu embedding rows w.r.t. ids into CUDA memory - Args: ids (torch.Tensor): the ids to be computed Returns: torch.Tensor: indices on the cuda_cached_weight. """ - with record_function(f"(pre-id) get unique indices. cache ratio {self.cuda_row_num / self.num_embeddings}"): - ids = ids.to(self._cache_dev) + with record_function("(zhg) get unique indices"): cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) assert len(cpu_row_idxs) <= self.cuda_row_num, \ @@ -303,29 +273,26 @@ class CachedParamMgr(torch.nn.Module): f"Please increase cuda_row_num or decrease the training batch size." self.evict_backlist = cpu_row_idxs - with record_function("(pre-id) get cpu row idxs"): - comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, - self.cached_idx_map, - assume_unique=True, - invert=True)] + with record_function("(zhg) get cpu row idxs"): + comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs)) self.num_write_back_history.append(0) # move sure the cuda rows will not be evicted! - with record_function("(pre-id) cache update"): + with record_function("(zhg) cache update"): self._prepare_rows_on_cuda(comm_cpu_row_idxs) - self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) - with record_function("(pre-id) embed cpu rows idx -> cache gpu row idxs"): + self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) + + with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"): gpu_row_idxs = self._id_to_cached_cuda_id(ids) # update for LFU. if self._evict_strategy == EvictionStrategy.LFU: - with record_function("(pre-id) lfu cnter updates"): - unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] - self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) + unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] + self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) return gpu_row_idxs @@ -341,14 +308,13 @@ class CachedParamMgr(torch.nn.Module): @torch.no_grad() def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: """prepare rows in cpu_row_idxs on CUDA memory - Args: cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA """ evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num if evict_num > 0: with Timer() as timer: - mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True) + mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) if self._evict_strategy == EvictionStrategy.DATASET: # mask method. @@ -375,8 +341,7 @@ class CachedParamMgr(torch.nn.Module): tgt=self.weight.view(self.num_embeddings, -1)) else: # allocate tmp memory on CPU and copy rows on CUDA to CPU. - rows = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs.cuda()).cpu() + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu() self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows) self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) @@ -395,12 +360,12 @@ class CachedParamMgr(torch.nn.Module): if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=cpu_row_idxs.cpu(), - tgt_index=slots.cuda(), + tgt_index=slots, src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots.cuda(), rows) + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows) slot_offsets = slots self.cached_idx_map[slots] = cpu_row_idxs self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) @@ -421,7 +386,6 @@ class CachedParamMgr(torch.nn.Module): def _evict(self) -> int: """ deprecated - evict one row from cuda to cpu. Returns: (int) : the slot id be evicted. @@ -463,9 +427,7 @@ class CachedParamMgr(torch.nn.Module): def _admit(self, row_id: int): """ deprecated - move in row_id to CUDA - Args: row_id (int): the id of row to be moved in """ @@ -491,4 +453,4 @@ class CachedParamMgr(torch.nn.Module): self._cuda_available_row_num -= 1 self._cpu_to_cuda_numel += self.embedding_dim - self._cpu_to_cuda_elpase += timer.elapsed + self._cpu_to_cuda_elpase += timer.elapsed \ No newline at end of file diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index 5bb654217..039301a7e 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -90,7 +90,7 @@ def test_reorder_with_freq(): offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) weight = torch.rand(num_embed, 2) - mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu') + mgr = CachedParamMgr(weight, num_chunk) mgr.reorder(idx_map)