[shardformer]gather llama logits (#5398)

* gather llama logits

* fix
flybird11111 2024-02-27 22:44:07 +08:00 committed by GitHub
parent dcdd8a5ef7
commit 0a25e16e46
2 changed files with 8 additions and 0 deletions

colossalai/shardformer/modeling/llama.py

@@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
+from ..layer._operation import _gather
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -288,6 +289,9 @@ class LlamaPipelineForwards:
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 loss = loss_fct(shift_logits, shift_labels)
 
+            if not shard_config.parallel_output:
+                logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
@@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
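
For reference, `_gather(logits, -1, shard_config.tensor_parallel_process_group)` collects the vocabulary-sharded logits from every tensor-parallel rank and concatenates them along the last dimension, so callers that set `parallel_output` to False get full-vocabulary logits back. Below is a minimal sketch of that behaviour written against plain torch.distributed; it is not ColossalAI's `_gather`, and the helper name `gather_last_dim` is made up for illustration.

    import torch
    import torch.distributed as dist

    def gather_last_dim(logits: torch.Tensor, group=None) -> torch.Tensor:
        # Sketch only: mirror what a last-dim gather over the tensor
        # parallel group does, not the ColossalAI implementation.
        world_size = dist.get_world_size(group)
        if world_size == 1:
            return logits
        # Each rank holds a [batch, seq, vocab // world_size] shard.
        shards = [torch.empty_like(logits) for _ in range(world_size)]
        dist.all_gather(shards, logits.contiguous(), group=group)
        # Reassemble the full [batch, seq, vocab] logits tensor.
        return torch.cat(shards, dim=-1)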

colossalai/shardformer/shard/shard_config.py

@@ -34,6 +34,7 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
+    parallel_output = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
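
Note that `parallel_output = True` is added here as an unannotated class attribute rather than a dataclass field, so it is not part of ShardConfig's generated __init__; a caller would toggle it on the instance (or the class) after construction. A hedged usage sketch follows; the setup lines and the `enable_tensor_parallelism` argument are assumptions, not taken from this diff.

    import torch.distributed as dist
    from colossalai.shardformer.shard import ShardConfig

    dist.init_process_group("nccl")  # assumes a distributed launch, e.g. via torchrun
    tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,  # field name assumed from context
    )
    # parallel_output is a plain class attribute, so set it on the instance:
    shard_config.parallel_output = False  # patched forwards will then _gather the logits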