From 753db97eb3e6e58b3e223fd94b77714922bb0acf Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 9 Nov 2024 08:02:25 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/modeling/gemma2.py | 29 +++++--------
 colossalai/shardformer/modeling/llama.py  |  4 +-
 colossalai/shardformer/policies/gemma2.py | 50 +++++++++--------------
 3 files changed, 34 insertions(+), 49 deletions(-)

diff --git a/colossalai/shardformer/modeling/gemma2.py b/colossalai/shardformer/modeling/gemma2.py
index 12726a462..75e46c41c 100644
--- a/colossalai/shardformer/modeling/gemma2.py
+++ b/colossalai/shardformer/modeling/gemma2.py
@@ -1,27 +1,18 @@
-import math
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional
 
 import torch
 import torch.distributed
 import torch.utils.checkpoint
-from transformers.cache_utils import DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-from transformers.models.gemma2.modeling_gemma2 import (
-    Gemma2ForCausalLM,
-    Gemma2Model,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import gather_sp_output
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, RingAttention, dist_cross_entropy
+from ..layer import RingAttention, dist_cross_entropy
 
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
@@ -88,7 +79,7 @@ class Gemma2PipelineForwards:
 
         # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
+        shard_config.sequence_parallel_process_group
         sp_size = shard_config.sequence_parallel_size
         # Generating full positions ids for modes that gather sequence before attn
         if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
@@ -97,7 +88,7 @@ class Gemma2PipelineForwards:
             past_seen_tokens = 0
         cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
-        seq_length_with_past = seq_length + past_seen_tokens
+        seq_length + past_seen_tokens
 
         if output_attentions:
             logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
@@ -108,11 +99,13 @@ class Gemma2PipelineForwards:
         if use_cache:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
-        
+
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+        attn_kwargs: torch.Tensor = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+        )
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 7aa3a8310..5309bcd6d 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -141,7 +141,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+            )
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
diff --git a/colossalai/shardformer/policies/gemma2.py b/colossalai/shardformer/policies/gemma2.py
index 972da1bea..3e8815751 100644
--- a/colossalai/shardformer/policies/gemma2.py
+++ b/colossalai/shardformer/policies/gemma2.py
@@ -1,8 +1,8 @@
 from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import Dict, Union
 
 import torch.nn as nn
-from torch import Tensor
+
 from colossalai.shardformer.layer import (
     Linear1D_Col,
     Linear1D_Row,
@@ -13,10 +13,12 @@ from colossalai.shardformer.layer import (
     VocabParallelLMHead1D,
 )
 
-from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from ..modeling.gemma2 import Gemma2PipelineForwards
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
 __all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]
+
 
 class Gemma2Policy(Policy):
     def config_sanity_check(self):
         pass
@@ -26,10 +28,8 @@ class Gemma2Policy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.gemma2.modeling_gemma2 import (
-            Gemma2DecoderLayer,
-            Gemma2Model,
-        )
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model
+
         policy = {}
 
         embedding_cls = None
@@ -53,15 +53,9 @@ class Gemma2Policy(Policy):
         policy[Gemma2DecoderLayer] = ModulePolicyDescription(
             attribute_replacement=decoder_attribute_replacement,
             sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="mlp.gate_proj",
-                    target_module=Linear1D_Col),
-                SubModuleReplacementDescription(
-                    suffix="mlp.up_proj",
-                    target_module=Linear1D_Col),
-                SubModuleReplacementDescription(
-                    suffix="mlp.down_proj",
-                    target_module=Linear1D_Row),
+                SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
+                SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
+                SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
@@ -78,7 +72,7 @@ class Gemma2Policy(Policy):
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
                 ),
-            ]
+            ],
         )
 
         if embedding_cls is not None:
@@ -94,18 +88,10 @@ class Gemma2Policy(Policy):
 
         self.append_or_create_submodule_replacement(
             description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="pre_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
             ],
             policy=policy,
             target_key=Gemma2DecoderLayer,
@@ -145,7 +131,11 @@ class Gemma2ForCausalLMPolicy(Gemma2Policy):
                 target_key=Gemma2ForCausalLM,
             )
         if self.shard_config.parallel_output:
-            method_replacement = {"forward": partial(Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config)}
+            method_replacement = {
+                "forward": partial(
+                    Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
+                )
+            }
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
             )