mirror of https://github.com/hpcaitech/ColossalAI
[embedding] fix a bug in table wise sharding (#1538)
parent 87134524fd
commit 521078ffc9
@@ -64,7 +64,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
         self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
         self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
         self.global_tables_num = len(embedding_bag_config_list)
-        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
+        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
 
         self.assigned_table_list: List[int] = []
         for i, rank in enumerate(self.rank_of_tables):
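Context for the fix: the change moves `global_tables_offsets` onto the GPU with `.cuda()`. In table-wise sharding, the cumulative offsets map a (table id, local index) pair to a row in the concatenated embedding space, and they are combined with index tensors that already live on the GPU; keeping the offsets on the CPU leads to a device-mismatch error. The snippet below is a minimal sketch of that failure mode and the fix, not ColossalAI code: the names `num_embeddings_per_table`, `table_ids`, and `indices` are hypothetical stand-ins for the values the real module builds from `embedding_bag_config_list` and its inputs.

    # Minimal sketch (assumed names, not the ColossalAI forward path).
    import torch

    num_embeddings_per_table = [10, 20, 30]  # hypothetical per-table sizes
    offsets = torch.cumsum(torch.tensor([0] + num_embeddings_per_table), 0)  # [0, 10, 30, 60]

    if torch.cuda.is_available():
        table_ids = torch.tensor([1, 2], device="cuda")  # which table each lookup targets
        indices = torch.tensor([3, 7], device="cuda")    # local row index within that table

        # Without the fix, offsets stays on the CPU, and combining it with the CUDA
        # tensors above raises a device-mismatch RuntimeError.
        offsets = offsets.cuda()  # the fix: keep offsets on the same device as the indices

        # Map (table_id, local_index) to a global row index in the concatenated tables.
        global_indices = offsets[table_ids] + indices
        print(global_indices)  # tensor([13, 37], device='cuda:0')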