diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
index c816f85b9..5ea3678f1 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                                evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(input)
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
+        if cache_op:
+            with torch.no_grad():
+                input = self.cache_weight_mgr.prepare_ids(input)
 
-        embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                     self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+        embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                     self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             embeddings = shape_hook(embeddings)
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index 1171496fc..24ec7599c 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                           compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
 
-    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-        output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+    def forward(self,
+                indices,
+                offsets=None,
+                per_sample_weights=None,
+                shape_hook=None,
+                scatter_dim=0,
+                gather_dim=-1,
+                cache_op: bool = True):
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(indices)
+        output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             output_shard = shape_hook(output_shard)
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index ab8a232b1..8f8180ed1 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 offsets: torch.Tensor = None,
                 per_sample_weights=None,
                 shape_hook=None,
-                already_split_along_rank=True):
+                already_split_along_rank=True,
+                cache_op=True):
         if not already_split_along_rank:
             # not recommanded. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommanded.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
         local_output = torch.cat(local_output.split(batch_size), 1)
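
Usage sketch for reviewers: a minimal example of how the new cache_op flag could be
exercised on the single-process FreqAwareEmbeddingBag. The import path and the
constructor arguments (num_embeddings, embedding_dim, cuda_row_num) are illustrative
assumptions, not taken from this patch; the only behaviour the patch defines is that
forward(..., cache_op=False) skips the prepare_ids cache step, so the caller is then
responsible for passing ids that are already remapped to cache rows.

    # Hypothetical usage of the new cache_op flag; constructor arguments and the
    # import path below are assumptions for illustration only.
    import torch
    from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag

    bag = FreqAwareEmbeddingBag(num_embeddings=1000, embedding_dim=16, cuda_row_num=128)

    ids = torch.randint(0, 1000, (256,), device='cuda')
    offsets = torch.arange(0, 256, 4, device='cuda')

    # Default path: forward() itself moves the touched rows into the CUDA cache
    # and remaps ids before calling F.embedding_bag.
    out = bag(ids, offsets)

    # If the ids were already remapped to cache slots elsewhere (e.g. a stage that
    # called prepare_ids ahead of time), the cache step can be skipped.
    with torch.no_grad():
        cached_ids = bag.cache_weight_mgr.prepare_ids(ids)
    out = bag(cached_ids, offsets, cache_op=False)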