From 4e50cce26bc5d7aa6c14419c2394bcbc9cc863bf Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Tue, 7 May 2024 09:17:56 +0000
Subject: [PATCH] fix the mistral model

Remove the tensor-parallel lm_head split gated on `config.pretraining_tp`.
That field is LLaMA-specific and is not defined on MistralConfig, so the
branch cannot run for Mistral; keep only the plain lm_head projection.
---
 colossalai/shardformer/modeling/mistral.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 796aeca51..93da71abb 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -683,12 +683,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )
 
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()
 
         loss = None
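
Note (not part of the applied patch): below is a minimal, self-contained
sketch of the retained code path, using toy dimensions and standalone names
rather than the actual ColossalAI module. It also probes MistralConfig to
show why the deleted guard could not run for Mistral: `pretraining_tp` is a
LLaMA-era config field that MistralConfig does not define, so evaluating
`self.config.pretraining_tp > 1` would raise AttributeError.

    import torch
    import torch.nn as nn
    from transformers import MistralConfig

    # Toy sizes, illustrative only.
    hidden_size, vocab_size = 16, 32
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # Equivalent of the retained path: plain projection, then upcast.
    hidden_states = torch.randn(2, 5, hidden_size)   # (batch, seq_len, hidden)
    logits = lm_head(hidden_states).float()          # (batch, seq_len, vocab)
    print(logits.shape)                              # torch.Size([2, 5, 32])

    # The removed branch's guard is unreachable for Mistral: the config
    # simply has no such attribute.
    print(hasattr(MistralConfig(), "pretraining_tp"))  # False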