[format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>
pull/4309/merge
github-actions[bot] 2024-03-27 11:21:03 +08:00 committed by GitHub
parent 19e1a5cf16
commit e6707a6e8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 7 deletions

View File

@ -1302,7 +1302,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

View File

@ -15,10 +15,8 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

View File

@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:

View File

@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
policy.update(new_item)
if self.pipeline_stage_manager: