diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py index 464efe23d..35faa67b5 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py @@ -64,7 +64,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module): self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] self.global_tables_num = len(embedding_bag_config_list) - self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() self.assigned_table_list: List[int] = [] for i, rank in enumerate(self.rank_of_tables):