[embedding] fix a bug in table wise sharding (#1538)

pull/1532/head
Jiarui Fang 2 years ago committed by GitHub
parent 87134524fd
commit 521078ffc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -64,7 +64,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
self.global_tables_num = len(embedding_bag_config_list)
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
self.assigned_table_list: List[int] = []
for i, rank in enumerate(self.rank_of_tables):

Loading…
Cancel
Save