Browse Source

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/5684/head
pre-commit-ci[bot] 7 months ago
parent
commit
ca56b93d83
  1. 10
      colossalai/shardformer/modeling/opt.py
  2. 14
      colossalai/shardformer/policies/opt.py

10
colossalai/shardformer/modeling/opt.py

@ -21,7 +21,9 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
logger = logging.get_logger(__name__)
@ -351,7 +353,7 @@ class OPTPipelineForwards:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
@ -987,8 +989,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
#loss_fct = CrossEntropyLoss()
#loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
@ -1002,4 +1004,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
attentions=outputs.attentions,
)
return forward
return forward

14
colossalai/shardformer/policies/opt.py

@ -21,9 +21,9 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.opt import (
OPTPipelineForwards,
get_jit_fused_opt_decoder_layer_forward,
get_lm_forward_with_dist_cross_entropy,
get_opt_decoder_forward_for_flash_attention,
get_opt_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -270,21 +270,17 @@ class OPTForCausalLMPolicy(OPTPolicy):
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
),
policy=policy,
target_key=OPTForCausalLM,
)
if self.shard_config.parallel_output:
method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
self.append_or_create_method_replacement(
description=method_replacement,
policy=policy,
target_key=OPTForCausalLM
description=method_replacement, policy=policy, target_key=OPTForCausalLM
)
else:
self.append_or_create_submodule_replacement(

Loading…
Cancel
Save