mirror of https://github.com/hpcaitech/ColossalAI
parent
dcdd8a5ef7
commit
0a25e16e46
|
@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer._operation import _gather
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
||||
|
@ -288,6 +289,9 @@ class LlamaPipelineForwards:
|
|||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
|
|
@ -34,6 +34,7 @@ class ShardConfig:
|
|||
enable_all_optimization: bool = False
|
||||
enable_sequence_parallelism: bool = False
|
||||
enable_sequence_overlap: bool = False
|
||||
parallel_output = True
|
||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
# pipeline_parallel_size: int
|
||||
# data_parallel_size: int
|
||||
|
|
Loading…
Reference in New Issue