From a389ac4ec93aa9596417fe0b8ec74e7f88c8882c Mon Sep 17 00:00:00 2001
From: CsRic <59389055+CsRic@users.noreply.github.com>
Date: Thu, 8 Sep 2022 16:41:19 +0800
Subject: [PATCH] [embedding] cache_embedding small improvement (#1564)

---
 .../layers/cache_embedding/cache_mgr.py       |  7 ++--
 .../parallel_freq_aware_embedding.py          |  5 ++-
 ...parallel_freq_aware_embedding_tablewise.py | 36 +++++++++++++++----
 3 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index fdb120134..e7daf5355 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
         """reorder
         reorder the weight according to ids' frequency in dataset before training.
         Execute only once before training, also known as warmup phase.
-        
+
         Note:
             If you would like to use the DATASET as the eviction strategy, you must call this function.
@@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
         self.evict_backlist = cpu_row_idxs

         with record_function("(pre-id) get cpu row idxs"):
-            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
+            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(
+                cpu_row_idxs, self.cached_idx_map, assume_unique=True, invert=True)]

         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
         self.num_miss_history.append(len(comm_cpu_row_idxs))
@@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
         if evict_num > 0:
             with Timer() as timer:
-                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
+                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
                 invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
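The functional change in cache_mgr.py is the new `assume_unique=True` hint to `torch.isin`, which lets PyTorch skip its internal deduplication pass. Both call sites appear to satisfy the precondition: `cpu_row_idxs` appears to be produced by `torch.unique` upstream, and `cached_idx_map` holds one row id per cache slot. A minimal standalone sketch of the miss computation, with invented ids (not taken from the patch):

    import torch

    # stand-ins for the manager's tensors; the values are illustrative only
    cached_idx_map = torch.tensor([3, 7, 11, 42])            # one id per cached slot
    cpu_row_idxs = torch.unique(torch.tensor([7, 9, 42]))    # unique by construction

    # rows requested by the batch but absent from the cache (the misses)
    comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, cached_idx_map,
                                                assume_unique=True, invert=True)]
    print(comm_cpu_row_idxs)    # tensor([9])

If either tensor could ever contain duplicates, `assume_unique=True` would be unsafe and the flag would have to be dropped.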
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index e53b126b7..61d870fad 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-
         output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        per_sample_weights, self.include_last_offset, self.padding_idx)
@@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):

     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
-    
+
     def element_size(self):
-        return self.weight.element_size()
\ No newline at end of file
+        return self.weight.element_size()

diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index c0d72fbfc..d2f6b7c53 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         local_offsets_list: List(torch.Tensor) = []
         local_per_sample_weights_list: List(torch.Tensor) = []
         offset_pre_end = 0    # local_offsets trick
+
         for i, handle_table in enumerate(self.assigned_table_list):
             indices_start_position = offsets[batch_size * handle_table]
             if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@@ -94,6 +95,28 @@
                 indices_end_position = indices.shape[0]
             else:
                 indices_end_position = offsets[batch_size * (handle_table + 1)]
+            # alternative approach: reduce malloc
+            '''
+            # 1. local_indices_list:
+            local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
+            torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
+            local_indices_list.append(local_indices)
+            # 2. local_offsets_list:
+            if i + 1 == len(self.assigned_table_list):
+                # till-the-end special case
+                if not self.include_last_offset:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
+                else:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
+                torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
+                local_offsets_list.append(local_offsets)
+            else:
+                temp_holder = offsets[batch_size * handle_table].item()
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
+                torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
+                offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
+                local_offsets_list.append(local_offsets)
+            '''
             # 1. local_indices_list:
             local_indices_list.append(
                 indices.narrow(0, indices_start_position,
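The block kept under the triple quotes above is an alternative that trades allocations for in-place arithmetic: `narrow()` returns a view of `indices`/`offsets`, and `torch.sub`/`torch.add` with `out=` write through that view instead of materializing new tensors. The cost is that it mutates the caller's buffers, which is presumably why it stays commented out. A small standalone illustration of the difference (assuming nothing beyond PyTorch; the values are invented):

    import torch

    offsets = torch.tensor([0, 2, 5, 9])

    # out-of-place: .add() allocates and returns a fresh tensor
    rebased = offsets.narrow(0, 1, 2).add(10)
    print(rebased)    # tensor([12, 15])
    print(offsets)    # tensor([0, 2, 5, 9]) -- untouched

    # in-place: narrow() is a view, and out= writes straight into it
    view = offsets.narrow(0, 1, 2)
    torch.add(view, 10, out=view)
    print(offsets)    # tensor([ 0, 12, 15,  9]) -- caller's buffer was modified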
@@ -103,21 +126,20 @@
                 # till-the-end special case
                 if not self.include_last_offset:
                     local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size).add(offset_pre_end - offsets[batch_size *
-                                                                                            (handle_table)])
+                                                   batch_size).add(offset_pre_end - offsets[batch_size
+                                                                                            * (handle_table)])
                 else:
-                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
-                                                   1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+                                                   + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 local_offsets_list.append(local_offsets)
             else:
-                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
-                                               1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+                                               + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 offset_pre_end = local_offsets[-1]
                 local_offsets_list.append(local_offsets[:-1])
             # 3. local_per_sample_weights_list:
             if per_sample_weights != None:
                 local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
-
         local_indices = torch.cat(local_indices_list, 0)
         local_offsets = torch.cat(local_offsets_list, 0)
         local_per_sample_weights = None
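The last hunk only re-wraps the continuation lines; the arithmetic is unchanged. The code it touches is the "local_offsets trick" noted at the top of the loop: each assigned table's slice of the global `offsets` is rebased by `offset_pre_end - offsets[batch_size * handle_table]` so that, after `torch.cat`, the result is one valid offsets tensor over the concatenated local indices. A hypothetical standalone example with this rank assigned only table 1 of 2 (names mirror the patch; the values are invented):

    import torch

    batch_size = 2
    # global batch over two tables: table 0 holds bags [5, 6, 7], [8]; table 1 holds bags [9], [10]
    indices = torch.tensor([5, 6, 7, 8, 9, 10])
    offsets = torch.tensor([0, 3, 4, 5])    # offsets[batch_size * t] is where table t starts

    assigned_table_list = [1]    # this rank handles table 1 only
    offset_pre_end = 0           # running end of the local indices built so far
    local_offsets_list = []
    for handle_table in assigned_table_list:
        # last assigned table with include_last_offset == False: take batch_size entries
        local_offsets = offsets.narrow(0, batch_size * handle_table,
                                       batch_size).add(offset_pre_end - offsets[batch_size * handle_table])
        local_offsets_list.append(local_offsets)

    print(torch.cat(local_offsets_list, 0))    # tensor([0, 1]): rebased to local positions

Table 1's slice of `offsets` is `[4, 5]`; shifting it by `0 - 4` yields `[0, 1]`, which indexes correctly into the local indices `[9, 10]` extracted for this rank.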