Dev/zero offload (#5858)

* fix llama

* fix llama
dev/zero-offload
Wang Binluo 5 months ago committed by GitHub
parent de3f67d128
commit 868afdb311
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -803,8 +803,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits, shift_logits,
shift_labels, shift_labels,
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
) )
if not return_dict: if not return_dict:

Loading…
Cancel
Save