mirror of https://github.com/hpcaitech/ColossalAI

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

pull/6121/head
parent 4389089f1b
commit 753db97eb3
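This commit contains only mechanical clean-ups pushed by pre-commit.ci on top of the parent commit: unused imports and dead local assignments are removed, over-long calls are wrapped, short parenthesized imports are collapsed onto one line, and a trailing comma is added to a multi-line list. The sketch below shows how to run the same hooks locally instead of waiting for the bot; it assumes a checkout of the repository and that its .pre-commit-config.yaml (not shown in this diff) defines the isort/black/autoflake-style hooks that produced these fixes.

# Minimal sketch: run the repository's pre-commit hooks locally so formatting
# fixes land in the original commit rather than in a follow-up bot commit.
# Assumes `pip install pre-commit` and a checkout of the repository root.
import subprocess

# Register the hooks as a git pre-commit hook for future commits.
subprocess.run(["pre-commit", "install"], check=True)

# Run every configured hook against the whole tree; formatters rewrite files
# in place, and the command exits non-zero when any file was modified.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)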
@@ -1,27 +1,18 @@
 import math
 import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional

 import torch
 import torch.distributed
 import torch.utils.checkpoint
 from transformers.cache_utils import DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-from transformers.models.gemma2.modeling_gemma2 import (
-    Gemma2ForCausalLM,
-    Gemma2Model,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import gather_sp_output
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, RingAttention, dist_cross_entropy
+from ..layer import RingAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -88,7 +79,7 @@ class Gemma2PipelineForwards:

         # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
+        shard_config.sequence_parallel_process_group
         sp_size = shard_config.sequence_parallel_size
         # Generating full positions ids for modes that gather sequence before attn
         if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
@@ -97,7 +88,7 @@ class Gemma2PipelineForwards:
         past_seen_tokens = 0
         cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

-        seq_length_with_past = seq_length + past_seen_tokens
+        seq_length + past_seen_tokens

         if output_attentions:
             logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
@@ -108,11 +99,13 @@ class Gemma2PipelineForwards:
         if use_cache:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+        attn_kwargs: torch.Tensor = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+        )

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -141,7 +141,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+            )

         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
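A detail worth noting in the two Gemma2PipelineForwards hunks above (-88 and -97): when the unused-variable hook removes a dead assignment whose right-hand side is not provably free of side effects, it keeps that right-hand side as a bare expression statement, which is why sp_group = shard_config.sequence_parallel_process_group becomes a standalone shard_config.sequence_parallel_process_group instead of vanishing. Attributing this to autoflake's --remove-unused-variables is an assumption (the hook configuration is not part of this diff); the throwaway snippet below, with hypothetical names, reproduces that style of rewrite.

# Illustration only: an autoflake-style unused-variable fix keeps the
# right-hand side of a dead assignment as an expression statement, matching
# the hunks above. Function and variable names here are hypothetical.
import pathlib
import subprocess
import tempfile

SOURCE = """\
def forward(shard_config, seq_length, past_seen_tokens):
    sp_group = shard_config.sequence_parallel_process_group
    seq_length_with_past = seq_length + past_seen_tokens
    return seq_length
"""

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
    tmp.write(SOURCE)
    path = tmp.name

# Requires `pip install autoflake`; rewrites the file in place.
subprocess.run(["autoflake", "--in-place", "--remove-unused-variables", path], check=True)

# Expected output: the assignments are gone but both right-hand sides remain
# as bare expressions, mirroring the diff above.
print(pathlib.Path(path).read_text())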
@@ -1,8 +1,8 @@
 from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import Dict, Union

 import torch.nn as nn
 from torch import Tensor

 from colossalai.shardformer.layer import (
     Linear1D_Col,
     Linear1D_Row,
@@ -13,10 +13,12 @@ from colossalai.shardformer.layer import (
     VocabParallelLMHead1D,
 )

-from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from ..modeling.gemma2 import Gemma2PipelineForwards
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]


 class Gemma2Policy(Policy):
     def config_sanity_check(self):
         pass
@@ -26,10 +28,8 @@ class Gemma2Policy:
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.gemma2.modeling_gemma2 import (
-            Gemma2DecoderLayer,
-            Gemma2Model,
-        )
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model

         policy = {}

         embedding_cls = None
@@ -53,15 +53,9 @@ class Gemma2Policy:
         policy[Gemma2DecoderLayer] = ModulePolicyDescription(
             attribute_replacement=decoder_attribute_replacement,
             sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="mlp.gate_proj",
-                    target_module=Linear1D_Col),
-                SubModuleReplacementDescription(
-                    suffix="mlp.up_proj",
-                    target_module=Linear1D_Col),
-                SubModuleReplacementDescription(
-                    suffix="mlp.down_proj",
-                    target_module=Linear1D_Row),
+                SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
+                SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
+                SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
@@ -78,7 +72,7 @@ class Gemma2Policy:
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
                 ),
-            ]
+            ],
         )

         if embedding_cls is not None:
@@ -94,18 +88,10 @@ class Gemma2Policy:

         self.append_or_create_submodule_replacement(
             description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="pre_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
             ],
             policy=policy,
             target_key=Gemma2DecoderLayer,
@@ -145,7 +131,11 @@ class Gemma2ForCausalLMPolicy(Gemma2Policy):
                 target_key=Gemma2ForCausalLM,
             )
         if self.shard_config.parallel_output:
-            method_replacement = {"forward": partial(Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config)}
+            method_replacement = {
+                "forward": partial(
+                    Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
+                )
+            }
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
             )
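For context on the policy file: Gemma2Policy and Gemma2ForCausalLMPolicy are consumed by Shardformer, which applies the submodule and method replacements listed above when a model is sharded. The snippet below is a rough usage sketch, not part of this commit; the launch call, the ShardConfig arguments, and the return value of optimize() are assumptions that may differ across ColossalAI versions, so check the shardformer README before relying on them.

# Rough usage sketch (not from this commit): applying a shardformer policy to
# a Hugging Face model. Run under torchrun so distributed init succeeds; the
# exact APIs below are assumptions and may differ between ColossalAI versions.
import colossalai
from transformers import AutoModelForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch()  # older releases expect launch_from_torch(config={})

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

# ShardConfig carries the parallel settings that the policy reads, e.g. the
# shard_config.parallel_output flag checked in Gemma2ForCausalLMPolicy above.
shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)

# optimize() looks up a policy for the model class and applies its submodule
# replacements (Linear1D_Col / Linear1D_Row, VocabParallelLMHead1D) and
# forward replacements (Gemma2PipelineForwards.*).
sharded_model, shared_params = shard_former.optimize(model)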