mirror of https://github.com/hpcaitech/ColossalAI

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Branch: pull/5684/head
parent 108ddfb795
commit ca56b93d83
@@ -21,7 +21,9 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
+
 from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)
 
 
@@ -987,8 +989,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
             )
-            #loss_fct = CrossEntropyLoss()
-            #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            # loss_fct = CrossEntropyLoss()
+            # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[1:]
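For context on the hunk above: the replaced forward computes the loss with cross_entropy_1d over vocabulary-sharded (tensor-parallel) logits instead of the commented-out CrossEntropyLoss on gathered logits. Below is a minimal, single-process sketch of that idea, assuming a Megatron-style vocab-parallel cross entropy; plain Python sums stand in for all-reduce across ranks, and the names vocab_parallel_cross_entropy, logit_shards, and vocab_per_shard are invented for this example, not part of ColossalAI's API.

# Minimal single-process sketch of vocab-parallel cross entropy.
# Each "rank" holds a slice of the vocabulary dimension of the logits;
# the loss is computed without ever gathering the full logits.
# Python sums stand in for dist.all_reduce; this is NOT ColossalAI's code.
import torch

def vocab_parallel_cross_entropy(logit_shards, labels, vocab_per_shard):
    # 1) global max over the vocab dim (would be an all-reduce MAX across shards)
    global_max = torch.stack([s.max(dim=-1).values for s in logit_shards]).max(dim=0).values

    # 2) global sum of exp(logit - max) (would be an all-reduce SUM across shards)
    sum_exp = sum(torch.exp(s - global_max[:, None]).sum(dim=-1) for s in logit_shards)

    # 3) pick the target-token logit from whichever shard owns that vocab slice
    target_logit = torch.zeros_like(global_max)
    for rank, shard in enumerate(logit_shards):
        lo, hi = rank * vocab_per_shard, (rank + 1) * vocab_per_shard
        mask = (labels >= lo) & (labels < hi)
        local_idx = (labels - lo).clamp(0, vocab_per_shard - 1)
        picked = shard.gather(-1, local_idx[:, None]).squeeze(-1)
        target_logit += torch.where(mask, picked, torch.zeros_like(global_max))

    # cross entropy per token: log(sum_exp) - (target_logit - max); mean reduction
    return (torch.log(sum_exp) - (target_logit - global_max)).mean()

# toy check against torch.nn.functional.cross_entropy on the full logits
torch.manual_seed(0)
full = torch.randn(4, 8)              # 4 tokens, vocab size 8
labels = torch.randint(0, 8, (4,))
shards = list(full.chunk(2, dim=-1))  # two "tensor-parallel ranks", 4 vocab entries each
print(vocab_parallel_cross_entropy(shards, labels, vocab_per_shard=4))
print(torch.nn.functional.cross_entropy(full, labels))

The two printed values should agree up to floating-point rounding, since the sharded computation is just a rearrangement of the standard log-sum-exp form of cross entropy.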
@@ -21,9 +21,9 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
 from ..modeling.opt import (
     OPTPipelineForwards,
     get_jit_fused_opt_decoder_layer_forward,
+    get_lm_forward_with_dist_cross_entropy,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
-    get_lm_forward_with_dist_cross_entropy
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -271,20 +271,16 @@ class OPTForCausalLMPolicy(OPTPolicy):
                 target_module=VocabParallelLMHead1D,
                 kwargs=dict(
                     gather_output=not self.shard_config.parallel_output,
-                    make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                    make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                 ),
             ),
             policy=policy,
             target_key=OPTForCausalLM,
         )
         if self.shard_config.parallel_output:
-            method_replacement = {
-                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-            }
+            method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
             self.append_or_create_method_replacement(
-                description=method_replacement,
-                policy=policy,
-                target_key=OPTForCausalLM
+                description=method_replacement, policy=policy, target_key=OPTForCausalLM
             )
         else:
             self.append_or_create_submodule_replacement(
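The parallel_output branch above swaps the model's forward for the distributed-cross-entropy version through a method replacement on the policy. A rough, self-contained illustration of that mechanism follows; TinyLM, dist_cross_entropy_forward, and replace_forward are invented names, and the real work is done by the policy's append_or_create_method_replacement, whose internals are not shown here.

# Toy illustration of a "method replacement": rebind forward on a model instance.
# This only sketches the {"forward": fn} idea from the hunk above; it is not
# ColossalAI's Policy machinery.
from types import MethodType

class TinyLM:
    def forward(self, x):
        return f"plain forward({x})"

def dist_cross_entropy_forward(self, x):
    # stands in for the function returned by get_lm_forward_with_dist_cross_entropy(shard_config)
    return f"forward({x}) with vocab-parallel cross entropy"

def replace_forward(model, new_forward):
    # bind the replacement as an instance method, mirroring a {"forward": fn} description
    model.forward = MethodType(new_forward, model)

m = TinyLM()
print(m.forward("batch"))        # plain forward(batch)
replace_forward(m, dist_cross_entropy_forward)
print(m.forward("batch"))        # forward(batch) with vocab-parallel cross entropy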