mirror of https://github.com/hpcaitech/ColossalAI

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

ref: pull/6122/head
parent: 4389089f1b
commit: 753db97eb3
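The diff below is typical lint auto-fix output: imports and local assignments that are no longer referenced are removed, parenthesised imports that fit within the project's line-length limit are joined onto one line, and over-long lines are wrapped. As a minimal sketch of the import cleanup in the first hunk, assuming autoflake/isort-style hooks (the repository's actual hook configuration is not shown on this page) and using illustrative names:

    # Toy module, before the hooks run; every name here is illustrative.
    import math  # never referenced below, so an unused-import fixer drops it
    from typing import (
        List,
        Optional,
    )  # short enough to be joined onto one line


    def first_item(values: List[int]) -> Optional[int]:
        """Return the first element, or None for an empty list."""
        return values[0] if values else None


    # After the hooks run, the import block reduces to:
    #     from typing import List, Optional
    print(first_item([3, 1, 2]))  # prints 3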
@@ -1,27 +1,18 @@
-import math
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional

 import torch
 import torch.distributed
 import torch.utils.checkpoint
-from transformers.cache_utils import DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-from transformers.models.gemma2.modeling_gemma2 import (
-    Gemma2ForCausalLM,
-    Gemma2Model,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import gather_sp_output
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, RingAttention, dist_cross_entropy
+from ..layer import RingAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -88,7 +79,7 @@ class Gemma2PipelineForwards:

         # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
+        shard_config.sequence_parallel_process_group
         sp_size = shard_config.sequence_parallel_size
         # Generating full positions ids for modes that gather sequence before attn
         if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
@@ -97,7 +88,7 @@ class Gemma2PipelineForwards:
         past_seen_tokens = 0
         cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

-        seq_length_with_past = seq_length + past_seen_tokens
+        seq_length + past_seen_tokens

         if output_attentions:
             logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
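The two hunks above drop assignments whose names (sp_group, seq_length_with_past) are never read again; the right-hand expressions are conservatively kept, since attribute access or arithmetic could in principle have side effects. A small self-contained sketch of the same rewrite, using a stand-in config object (ShardConfigStub is hypothetical, not a ColossalAI class):

    class ShardConfigStub:
        """Stand-in for the real shard config, only for this illustration."""

        sequence_parallel_process_group = None
        sequence_parallel_size = 2


    shard_config = ShardConfigStub()

    # Before the fix: the name is bound but never read afterwards.
    sp_group = shard_config.sequence_parallel_process_group

    # After the fix: the binding is gone, the bare expression remains.
    shard_config.sequence_parallel_process_group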
|
@@ -112,7 +103,9 @@ class Gemma2PipelineForwards:
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+        attn_kwargs: torch.Tensor = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+        )

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -141,7 +141,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values, output_attentions)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+            )

         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
@@ -1,8 +1,8 @@
 from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import Dict, Union

 import torch.nn as nn
-from torch import Tensor
+
 from colossalai.shardformer.layer import (
     Linear1D_Col,
     Linear1D_Row,
@@ -13,10 +13,12 @@ from colossalai.shardformer.layer import (
     VocabParallelLMHead1D,
 )

-from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from ..modeling.gemma2 import Gemma2PipelineForwards
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]
+
+
 class Gemma2Policy(Policy):
     def config_sanity_check(self):
         pass
@@ -26,10 +28,8 @@ class Gemma2Policy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.gemma2.modeling_gemma2 import (
-            Gemma2DecoderLayer,
-            Gemma2Model,
-        )
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model
+
         policy = {}

         embedding_cls = None
@@ -53,15 +53,9 @@ class Gemma2Policy(Policy):
         policy[Gemma2DecoderLayer] = ModulePolicyDescription(
             attribute_replacement=decoder_attribute_replacement,
             sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="mlp.gate_proj",
-                    target_module=Linear1D_Col),
-                SubModuleReplacementDescription(
-                    suffix="mlp.up_proj",
-                    target_module=Linear1D_Col),
-                SubModuleReplacementDescription(
-                    suffix="mlp.down_proj",
-                    target_module=Linear1D_Row),
+                SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
+                SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
+                SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
@@ -78,7 +72,7 @@ class Gemma2Policy(Policy):
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
                 ),
-            ]
+            ],
         )

         if embedding_cls is not None:
@@ -94,18 +88,10 @@ class Gemma2Policy(Policy):

         self.append_or_create_submodule_replacement(
             description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="pre_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_feedforward_layernorm",
-                    target_module=norm_cls),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
             ],
             policy=policy,
             target_key=Gemma2DecoderLayer,
@@ -145,7 +131,11 @@ class Gemma2ForCausalLMPolicy(Gemma2Policy):
                 target_key=Gemma2ForCausalLM,
             )
             if self.shard_config.parallel_output:
-                method_replacement = {"forward": partial(Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config)}
+                method_replacement = {
+                    "forward": partial(
+                        Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
+                    )
+                }
                 self.append_or_create_method_replacement(
                     description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
                 )
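In the last hunk, the replacement forward is built with functools.partial so that shard_config is bound once and the resulting callable keeps the original call signature when it is installed on the model. A self-contained sketch of that binding pattern (the function and values are placeholders, not the real ColossalAI objects):

    from functools import partial


    def causal_lm_forward(input_ids, *, shard_config=None):
        """Placeholder forward: just echoes what it was called with."""
        return {"input_ids": input_ids, "shard_config": shard_config}


    # Bind the keyword argument up front, as the policy's method_replacement
    # dict does; callers still pass only the usual positional arguments.
    method_replacement = {"forward": partial(causal_lm_forward, shard_config="stub-config")}

    print(method_replacement["forward"]([1, 2, 3]))
    # -> {'input_ids': [1, 2, 3], 'shard_config': 'stub-config'}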