diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index e10a7ed7d..92c709218 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
+from ..layer._operation import _gather
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -288,6 +289,9 @@ class LlamaPipelineForwards:
                     shift_logits = shift_logits.view(-1, self.config.vocab_size)
                     loss = loss_fct(shift_logits, shift_labels)
 
+            if not shard_config.parallel_output:
+                logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
@@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index b5c9e66e0..415fc6dd5 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -34,6 +34,7 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
+    parallel_output: bool = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
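Below is a minimal usage sketch, not part of the patch itself, showing how the new parallel_output flag is meant to be consumed. The tp_group argument and the distributed launch around it are assumed to exist in the caller's setup; exact plugin wiring may differ.

# A sketch assuming a typical ShardFormer setup; `tp_group` is a hypothetical,
# already-initialized tensor-parallel process group supplied by the caller.
import torch.distributed as dist
from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer


def build_sharded_llama(tp_group: dist.ProcessGroup) -> LlamaForCausalLM:
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,
    )
    # New flag from this patch: when False, the causal-LM forward all-gathers
    # the vocab-parallel logits along the last dimension, so callers receive
    # full-vocab logits instead of a tensor-parallel shard.
    shard_config.parallel_output = False

    model = LlamaForCausalLM(LlamaConfig())
    model, _ = ShardFormer(shard_config=shard_config).optimize(model)
    return model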