fix the mistral model

pull/5684/head
wangbluo 7 months ago
parent a8408b4d31
commit 4e50cce26b

@@ -683,11 +683,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()
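
For context, the change drops the pretraining_tp slicing of the lm_head weight and computes the logits with a single lm_head call followed by an upcast to float. The sketch below mirrors that simplified path on a toy module; MockCausalLM, its sizes, and the demo tensors are illustrative stand-ins only, not the real MistralForCausalLM or ColossalAI's sharded head.

# Minimal sketch of the logits path after this change, assuming a plain
# nn.Linear lm_head. All names and sizes here are hypothetical.
import torch
import torch.nn as nn


class MockCausalLM(nn.Module):
    def __init__(self, hidden_size: int = 16, vocab_size: int = 32):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # No pretraining_tp branch: project hidden states with one lm_head
        # call, then upcast the logits to float, matching the new code path.
        logits = self.lm_head(hidden_states)
        return logits.float()


if __name__ == "__main__":
    model = MockCausalLM()
    hidden = torch.randn(2, 4, 16)   # (batch, seq_len, hidden_size)
    print(model(hidden).shape)       # torch.Size([2, 4, 32])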
