[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/6121/head
pre-commit-ci[bot] 2024-11-09 08:02:25 +00:00
parent 4389089f1b
commit 753db97eb3
3 changed files with 34 additions and 49 deletions
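
A note on reproducing these fixes locally: the hooks that produced this commit run automatically on pre-commit.ci, but they can also be run from a working tree. A minimal sketch, assuming pre-commit is installed and the repository's own .pre-commit-config.yaml is present (the helper name below is hypothetical):

import subprocess

# Hypothetical helper: run every configured hook against the whole tree,
# mirroring what pre-commit.ci did for this commit.
def run_precommit_hooks() -> int:
    result = subprocess.run(["pre-commit", "run", "--all-files"])
    return result.returncode

if __name__ == "__main__":
    raise SystemExit(run_precommit_hooks())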

View File

@@ -1,27 +1,18 @@
import math
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional
import torch
import torch.distributed
import torch.utils.checkpoint
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2ForCausalLM,
Gemma2Model,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer._operation import gather_sp_output
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, RingAttention, dist_cross_entropy
from ..layer import RingAttention, dist_cross_entropy
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -88,7 +79,7 @@ class Gemma2PipelineForwards:
# Support SP + PP
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
# Generating full position ids for modes that gather sequence before attn
if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
@@ -97,7 +88,7 @@ class Gemma2PipelineForwards:
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
seq_length_with_past = seq_length + past_seen_tokens
seq_length + past_seen_tokens
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
@@ -108,11 +99,13 @@ class Gemma2PipelineForwards:
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
attn_kwargs: torch.Tensor = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
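
The hunk above keeps the stock cache_position / position_ids bookkeeping from the upstream Gemma2 forward. A toy illustration of that relationship in plain PyTorch (independent of the classes in this diff):

import torch

# 4 tokens already cached, 3 new tokens arriving in this forward pass
past_seen_tokens, seq_length = 4, 3
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length)
position_ids = cache_position.unsqueeze(0)  # add a batch dimension

print(cache_position)  # tensor([4, 5, 6])
print(position_ids)    # tensor([[4, 5, 6]])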

View File

@@ -141,7 +141,9 @@ class LlamaPipelineForwards:
invert=(sp_mode != "ring_attn"),
)
else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
attn_kwargs: torch.Tensor = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
)
# Support SP + PP. Later stages have already received the split input.
split_input = disable_pp or stage_manager.is_first_stage()
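
The comment above relies on the sequence-parallel convention that only the first pipeline stage splits the sequence; later stages already hold their local shard. A conceptual sketch of that split in plain torch (not ColossalAI's actual split_forward_gather_backward helper):

import torch

def split_sequence(hidden_states: torch.Tensor, sp_rank: int, sp_size: int) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]; seq_len is assumed divisible by sp_size
    return torch.chunk(hidden_states, sp_size, dim=1)[sp_rank]

x = torch.randn(2, 8, 16)                        # toy full-sequence input
local = split_sequence(x, sp_rank=0, sp_size=4)
assert local.shape == (2, 2, 16)                 # each rank keeps seq_len / sp_size tokens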

View File

@@ -1,8 +1,8 @@
from functools import partial
from typing import Callable, Dict, List, Union
from typing import Dict, Union
import torch.nn as nn
from torch import Tensor
from colossalai.shardformer.layer import (
Linear1D_Col,
Linear1D_Row,
@@ -13,10 +13,12 @@ from colossalai.shardformer.layer import (
VocabParallelLMHead1D,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from ..modeling.gemma2 import Gemma2PipelineForwards
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]
class Gemma2Policy(Policy):
def config_sanity_check(self):
pass
@@ -26,10 +28,8 @@ class Gemma2Policy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2DecoderLayer,
Gemma2Model,
)
from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model
policy = {}
embedding_cls = None
@@ -53,15 +53,9 @@ class Gemma2Policy(Policy):
policy[Gemma2DecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row),
SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
@@ -78,7 +72,7 @@ class Gemma2Policy(Policy):
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
]
],
)
if embedding_cls is not None:
@@ -94,18 +88,10 @@ class Gemma2Policy(Policy):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls),
SubModuleReplacementDescription(
suffix="pre_feedforward_layernorm",
target_module=norm_cls),
SubModuleReplacementDescription(
suffix="post_feedforward_layernorm",
target_module=norm_cls),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls),
SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
],
policy=policy,
target_key=Gemma2DecoderLayer,
@@ -145,7 +131,11 @@ class Gemma2ForCausalLMPolicy(Gemma2Policy):
target_key=Gemma2ForCausalLM,
)
if self.shard_config.parallel_output:
method_replacement = {"forward": partial(Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config)}
method_replacement = {
"forward": partial(
Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
)
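
The method_replacement hunk above binds the shard config into the replacement forward with functools.partial. A minimal sketch of that binding pattern with stand-in names (not the real ColossalAI or transformers classes):

from functools import partial

def causal_lm_forward(input_ids, shard_config=None):
    # the real gemma2_for_causal_lm_forward uses shard_config to decide how
    # outputs are split across ranks; here we just echo the relevant flag
    return {"input_ids": input_ids, "parallel_output": shard_config["parallel_output"]}

shard_config = {"parallel_output": True}       # stand-in for ShardConfig
bound_forward = partial(causal_lm_forward, shard_config=shard_config)

print(bound_forward([1, 2, 3]))
# {'input_ids': [1, 2, 3], 'parallel_output': True}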