diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index d2f6b7c53..267e88cfb 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -9,6 +9,7 @@ from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise
 from typing import List
+import time
 
 
 class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
@@ -79,8 +80,43 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for rank in self.rank_of_tables:
             self.embedding_dim_per_rank[rank] += embedding_dim
 
-    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
-        batch_size = (offsets.shape[0]) // self.global_tables_num
+    def forward(self,
+                indices: torch.Tensor,
+                offsets: torch.Tensor = None,
+                per_sample_weights=None,
+                shape_hook=None,
+                already_split_along_rank=True):
+        if not already_split_along_rank:
+            # Not recommended: splitting inside forward takes time.
+            batch_size = (offsets.shape[0]) // self.global_tables_num
+            local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
+                batch_size, indices, offsets, per_sample_weights)
+        else:
+            # Recommended: inputs were already split along their assigned ranks.
+            batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
+            local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
+        local_output = torch.cat(local_output.split(batch_size), 1)
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full
+
+    def split_along_rank(self,
+                         batch_size,
+                         indices: torch.Tensor,
+                         offsets: torch.Tensor = None,
+                         per_sample_weights=None):
+        '''
+        If the input indices and offsets haven't been split along their assigned ranks, this function does it.
+        It takes time, so consider splitting the data during batch loading instead.
+        '''
         local_indices_list: List(torch.Tensor) = []
         local_offsets_list: List(torch.Tensor) = []
         if per_sample_weights != None:
@@ -145,20 +181,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         local_per_sample_weights = None
         if per_sample_weights != None:
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
-                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
-        local_output = torch.cat(local_output.split(batch_size), 1)
-
-        remains = batch_size % self.world_size
-        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
-        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
-        if shape_hook is not None:
-            output_full = shape_hook(output_full)
-        return output_full
+        return local_indices, local_offsets, local_per_sample_weights
 
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 3f4dcb0d1..928fbef9c 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
        in KJT format
     '''
     res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
-                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
+                already_split_along_rank=False)
     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
     rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
     if rank == 0:
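Usage sketch for the new already_split_along_rank flag: the snippet below contrasts the slow fallback (forward splits the batch itself) with the recommended pre-split path. It assumes an already-initialized ParallelFreqAwareEmbeddingBagTablewise instance named model (distributed setup elided) and reuses the KJT-format batch from the test above; it is illustrative only, not part of the patch.

    import torch

    device = torch.device('cuda')
    # Full, unsplit KJT batch covering all global tables (same data as the test).
    indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device)
    offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)

    # Fallback path: forward() calls split_along_rank() on every step (slow).
    out = model(indices, offsets, already_split_along_rank=False)

    # Recommended path: split once (ideally in the data loader), then call
    # forward() with the default already_split_along_rank=True.
    batch_size = offsets.shape[0] // model.global_tables_num
    local_indices, local_offsets, _ = model.split_along_rank(batch_size, indices, offsets)
    out = model(local_indices, local_offsets)

Defaulting already_split_along_rank to True keeps the per-table scan in split_along_rank off the training hot path.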