[shardformer] fix embedding

Branch: pull/4445/head
Author: ver217, 1 year ago; committed by Hongxin Liu
Commit: 73a4144b91
1 changed file with 3 additions and 0 deletions:

colossalai/shardformer/layer/embedding.py (+3)

@@ -214,6 +214,9 @@ class VocabParallelEmbedding1D(ParallelModule):
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
         # padding index
         self.padding_idx = self._select_padding_idx(padding_idx)
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
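
For context, the three added lines derive a per-rank random state: every rank reads the same base seed via torch.random.initial_seed(), and create_randomizer_with_offset then offsets it (per the diff's comment, by randomizer index and rank) so each vocabulary shard is initialized from an independent RNG stream. Below is a minimal standalone sketch of that idea under assumed names; partition_bounds and offset_seed are hypothetical helpers illustrating the partition arithmetic and the seed offsetting, not ColossalAI's actual API.

import torch

# Hypothetical helper: each rank owns a contiguous slice of the vocabulary,
# mirroring vocab_start_index / vocab_end_index in the diff.
def partition_bounds(num_embeddings: int, tp_rank: int, tp_size: int):
    per_partition = num_embeddings // tp_size  # assumes vocab divisible by tp_size
    start = tp_rank * per_partition
    return start, start + per_partition

# Hypothetical helper: offset the shared base seed by the rank so shards
# draw from independent RNG streams rather than identical ones.
def offset_seed(base_seed: int, tp_rank: int) -> int:
    return base_seed + tp_rank

if __name__ == "__main__":
    vocab_size, embed_dim, tp_size = 1000, 16, 4
    base_seed = torch.random.initial_seed()  # same call the diff uses
    for rank in range(tp_size):  # simulate tp_size ranks in one process
        start, end = partition_bounds(vocab_size, rank, tp_size)
        gen = torch.Generator().manual_seed(offset_seed(base_seed, rank))
        # Each rank initializes only its own (end - start, embed_dim) shard.
        shard = torch.randn(end - start, embed_dim, generator=gen)
        print(f"rank {rank}: vocab [{start}, {end}), shard mean {shard.mean():.4f}")

Without an offset, every rank would initialize its shard from the same stream, and the assembled embedding would repeat the same values across partitions.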
