remove useless code

pull/5684/head
wangbluo 2024-05-01 09:23:43 +00:00
parent 9efc79ef24
commit 2632916329
2 changed files with 0 additions and 2 deletions

View File

@ -270,7 +270,6 @@ class MistralForwards:
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
#shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)

View File

@ -277,7 +277,6 @@ class MistralForCausalLMPolicy(MistralPolicy):
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs={
#gather_output=True,
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},