diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 5855dcc4f..d8ea2c74d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -803,8 +803,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, ) if not return_dict: