[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/5684/head
pre-commit-ci[bot] 7 months ago
parent 108ddfb795
commit ca56b93d83

@ -21,7 +21,9 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d from ..layer import cross_entropy_1d
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)

@ -21,9 +21,9 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.opt import ( from ..modeling.opt import (
OPTPipelineForwards, OPTPipelineForwards,
get_jit_fused_opt_decoder_layer_forward, get_jit_fused_opt_decoder_layer_forward,
get_lm_forward_with_dist_cross_entropy,
get_opt_decoder_forward_for_flash_attention, get_opt_decoder_forward_for_flash_attention,
get_opt_flash_attention_forward, get_opt_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy
) )
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -271,20 +271,16 @@ class OPTForCausalLMPolicy(OPTPolicy):
target_module=VocabParallelLMHead1D, target_module=VocabParallelLMHead1D,
kwargs=dict( kwargs=dict(
gather_output=not self.shard_config.parallel_output, gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
), ),
), ),
policy=policy, policy=policy,
target_key=OPTForCausalLM, target_key=OPTForCausalLM,
) )
if self.shard_config.parallel_output: if self.shard_config.parallel_output:
method_replacement = { method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, description=method_replacement, policy=policy, target_key=OPTForCausalLM
policy=policy,
target_key=OPTForCausalLM
) )
else: else:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(

Loading…
Cancel
Save