[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/6122/head
pre-commit-ci[bot] 2024-11-09 08:02:25 +00:00
parent 4389089f1b
commit 753db97eb3
3 changed files with 34 additions and 49 deletions

View File

@@ -1,27 +1,18 @@
-import math
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional

 import torch
 import torch.distributed
 import torch.utils.checkpoint
-from transformers.cache_utils import DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-from transformers.models.gemma2.modeling_gemma2 import (
-    Gemma2ForCausalLM,
-    Gemma2Model,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import gather_sp_output
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, RingAttention, dist_cross_entropy
+from ..layer import RingAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -88,7 +79,7 @@ class Gemma2PipelineForwards:
         # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
+        shard_config.sequence_parallel_process_group
         sp_size = shard_config.sequence_parallel_size

         # Generating full positions ids for modes that gather sequence before attn
         if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
@@ -97,7 +88,7 @@ class Gemma2PipelineForwards:
             past_seen_tokens = 0
             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

-        seq_length_with_past = seq_length + past_seen_tokens
+        seq_length + past_seen_tokens

         if output_attentions:
             logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
@@ -112,7 +103,9 @@ class Gemma2PipelineForwards:
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+        attn_kwargs: torch.Tensor = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+        )

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
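The cache-position bookkeeping kept as context in the hunks above is easy to check in isolation. A minimal standalone sketch of what those two lines compute, using plain PyTorch with illustrative values (seq_length and past_seen_tokens below are made up, not taken from the diff):

import torch

# Illustrative values: a fresh forward pass with an empty KV cache.
past_seen_tokens = 0
seq_length = 5
device = "cpu"

# Same expressions as the context lines above: absolute positions of the new tokens...
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
# ...and a batch-shaped position_ids tensor derived from them.
position_ids = cache_position.unsqueeze(0)

print(cache_position)       # tensor([0, 1, 2, 3, 4])
print(position_ids.shape)   # torch.Size([1, 5])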

View File

@@ -141,7 +141,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+            )

         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
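The call being re-wrapped above produces the causal attention mask that is passed along as attn_kwargs. As intuition only (this is not the transformers _update_causal_mask implementation, just a generic additive causal mask):

import torch

def toy_causal_mask(seq_length: int, dtype=torch.float32) -> torch.Tensor:
    # Positions above the diagonal get a large negative value so attention to
    # future tokens is suppressed; past and current positions stay at 0.0.
    mask = torch.full((seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
    return torch.triu(mask, diagonal=1)

print(toy_causal_mask(4))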

View File

@@ -1,8 +1,8 @@
 from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import Dict, Union

 import torch.nn as nn
-from torch import Tensor

 from colossalai.shardformer.layer import (
     Linear1D_Col,
     Linear1D_Row,
@@ -13,10 +13,12 @@ from colossalai.shardformer.layer import (
     VocabParallelLMHead1D,
 )

-from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from ..modeling.gemma2 import Gemma2PipelineForwards
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]


 class Gemma2Policy(Policy):
     def config_sanity_check(self):
         pass
@@ -26,10 +28,8 @@ class Gemma2Policy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.gemma2.modeling_gemma2 import (
-            Gemma2DecoderLayer,
-            Gemma2Model,
-        )
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model

         policy = {}
         embedding_cls = None
@@ -53,15 +53,9 @@ class Gemma2Policy(Policy):
             policy[Gemma2DecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="mlp.gate_proj",
-                        target_module=Linear1D_Col),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.up_proj",
-                        target_module=Linear1D_Col),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.down_proj",
-                        target_module=Linear1D_Row),
+                    SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
+                    SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
+                    SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
@@ -78,7 +72,7 @@ class Gemma2Policy(Policy):
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
                     ),
-                ]
+                ],
             )

         if embedding_cls is not None:
@@ -94,18 +88,10 @@
         self.append_or_create_submodule_replacement(
             description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="pre_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
             ],
             policy=policy,
             target_key=Gemma2DecoderLayer,
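The two hunks above only collapse one-argument-per-line SubModuleReplacementDescription entries onto single lines; the underlying pattern is a declarative mapping from a submodule suffix to the parallel module that replaces it. A rough standalone sketch of that idea with hypothetical stand-ins (ToyReplacement is not a ColossalAI class):

from dataclasses import dataclass

@dataclass
class ToyReplacement:
    suffix: str         # dotted path of the submodule inside a decoder layer
    target_module: str  # name of the parallelized module that replaces it

toy_policy = [
    ToyReplacement(suffix="mlp.gate_proj", target_module="Linear1D_Col"),
    ToyReplacement(suffix="mlp.up_proj", target_module="Linear1D_Col"),
    ToyReplacement(suffix="mlp.down_proj", target_module="Linear1D_Row"),
]

for repl in toy_policy:
    print(f"{repl.suffix} -> {repl.target_module}")
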
@@ -145,7 +131,11 @@ class Gemma2ForCausalLMPolicy(Gemma2Policy):
                 target_key=Gemma2ForCausalLM,
             )

         if self.shard_config.parallel_output:
-            method_replacement = {"forward": partial(Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config)}
+            method_replacement = {
+                "forward": partial(
+                    Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
+                )
+            }
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
             )
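The final hunk only re-wraps the method_replacement dict; the pattern itself is binding shard_config with functools.partial so the substituted forward keeps the original call signature. A minimal sketch with hypothetical stand-ins (ToyShardConfig and toy_forward are not ColossalAI names):

from functools import partial

class ToyShardConfig:
    parallel_output = True

def toy_forward(hidden_states, shard_config=None):
    # shard_config arrives pre-bound; callers still pass only the model inputs.
    return f"forward({hidden_states}) parallel_output={shard_config.parallel_output}"

method_replacement = {"forward": partial(toy_forward, shard_config=ToyShardConfig())}
print(method_replacement["forward"]("input_ids"))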