From e6707a6e8d81d584597d7b8e5d8578158544d8f5 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Wed, 27 Mar 2024 11:21:03 +0800
Subject: [PATCH] [format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions
---
 colossalai/shardformer/modeling/gpt2.py  |  1 -
 colossalai/shardformer/modeling/llama.py |  2 --
 colossalai/shardformer/policies/gpt2.py  |  8 ++++++--
 colossalai/shardformer/policies/llama.py | 10 ++++++++--
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 72f923bf0..ea22cfb15 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -1302,7 +1302,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
             )
 
-
         if not shard_config.parallel_output:
             lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
 
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 1f17144f5..29dc8200f 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -15,10 +15,8 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
-
 
 from ..layer import ColoAttention, cross_entropy_1d
-
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index fcf40fa39..5b43ecaed 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             GPT2LMHeadModel: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
+                        suffix="lm_head",
+                        target_module=col_nn.Linear1D_Col,
+                        kwargs={"gather_output": not self.shard_config.parallel_output},
                     )
                 ],
             )
         }
         if self.shard_config.parallel_output:
-            addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+            addon_module[GPT2LMHeadModel].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         module_policy.update(addon_module)
 
         if self.pipeline_stage_manager is not None:
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 37c2c261b..db8468713 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         new_item = {
             LlamaForCausalLM: ModulePolicyDescription(
                 sub_module_replacement=[
-                    SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
+                    SubModuleReplacementDescription(
+                        suffix="lm_head",
+                        target_module=Linear1D_Col,
+                        kwargs={"gather_output": not self.shard_config.parallel_output},
+                    )
                 ],
             )
         }
         if self.shard_config.parallel_output:
-            new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+            new_item[LlamaForCausalLM].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         policy.update(new_item)
 
         if self.pipeline_stage_manager: