mirror of https://github.com/hpcaitech/ColossalAI
fix the mistral model
parent: a8408b4d31
commit: 4e50cce26b
@@ -683,12 +683,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )

         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()

         loss = None
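For context, a minimal standalone sketch of what the removed branch computed; the sizes and variable names here are made up for illustration and are not part of the commit. LLaMA's `pretraining_tp` trick splits the `lm_head` weight row-wise, projects the hidden states against each slice, and concatenates the partial logits, which is numerically equivalent to the single full projection that replaces it. `MistralConfig` defines no `pretraining_tp` attribute, so reading `self.config.pretraining_tp` presumably broke the Mistral forward pass, which appears to be what this commit fixes.

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
hidden_size, vocab_size, pretraining_tp = 64, 128, 4
hidden_states = torch.randn(2, 8, hidden_size)         # (batch, seq_len, hidden)
lm_head_weight = torch.randn(vocab_size, hidden_size)  # full lm_head projection

# Sliced projection, as in the removed branch: split the weight row-wise
# and concatenate the per-slice logits along the vocab dimension.
slices = lm_head_weight.split(vocab_size // pretraining_tp, dim=0)
logits_sliced = torch.cat([F.linear(hidden_states, w) for w in slices], dim=-1)

# Plain projection, as in the line that replaces the branch.
logits_full = F.linear(hidden_states, lm_head_weight)

assert torch.allclose(logits_sliced, logits_full, atol=1e-5)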