diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py index 5ea3678f1..356d77bd2 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py @@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): cuda_row_num = int(num_embeddings * cache_ratio) # configure weight & cache self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) + self.cache_op = True def _weight_alloc(self, dtype, device): weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) @@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): evict_strategy=self.evict_strategy) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) - def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True): - if cache_op: + def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): + if self.cache_op: with torch.no_grad(): input = self.cache_weight_mgr.prepare_ids(input) @@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): def parameters(self, recurse: bool = True) -> Iterator[Parameter]: yield self.cache_weight_mgr.cuda_cached_weight + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + ############################# Perf Log ################################### diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py index 24ec7599c..28e6e0575 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py @@ -60,6 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight, evict_strategy) + self.cache_op = True def _weight_alloc(self, dtype, device): weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype) @@ -72,15 +73,16 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): compute_attr=ComputePattern.TP1D) return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) - def forward(self, - indices, - offsets=None, - per_sample_weights=None, - shape_hook=None, - scatter_dim=0, - gather_dim=-1, - cache_op: bool = True): - if cache_op: + def forward( + self, + indices, + offsets=None, + per_sample_weights=None, + shape_hook=None, + scatter_dim=0, + gather_dim=-1, + ): + if self.cache_op: with torch.no_grad(): indices = self.cache_weight_mgr.prepare_ids(indices) output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, @@ -94,6 +96,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): gather_dim=gather_dim) return output_full + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + @classmethod def from_pretrained( cls, diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py index 8f8180ed1..60803b928 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py @@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): for rank in self.rank_of_tables: self.embedding_dim_per_rank[rank] += embedding_dim - def forward(self, - indices: torch.Tensor, - offsets: torch.Tensor = None, - per_sample_weights=None, - shape_hook=None, - already_split_along_rank=True, - cache_op=True): + self.cache_op = True + + def forward( + self, + indices: torch.Tensor, + offsets: torch.Tensor = None, + per_sample_weights=None, + shape_hook=None, + already_split_along_rank=True, + ): if not already_split_along_rank: # not recommanded. it takes time. batch_size = (offsets.shape[0]) // self.global_tables_num @@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): # recommanded. batch_size = (offsets.shape[0]) // len(self.assigned_table_list) local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights - if cache_op: + if self.cache_op: with torch.no_grad(): indices = self.cache_weight_mgr.prepare_ids(local_indices) local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, @@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) return local_indices, local_offsets, local_per_sample_weights + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + def print_comm_stats_(self): self.cache_weight_mgr.print_comm_stats()