# Mirror of https://github.com/hpcaitech/ColossalAI
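"""ShardFormer policies for Hugging Face Gemma2 models.

`Gemma2Policy` describes how to shard the Gemma2 decoder stack (attention and MLP
projections, token embedding, RMSNorm layers); `Gemma2ForCausalLMPolicy` extends it
with the replacement of the `lm_head` output projection.
"""
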
from functools import partial
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import (
    Linear1D_Col,
    Linear1D_Row,
    PaddingEmbedding,
    PaddingLMHead,
    RMSNorm,
    VocabParallelEmbedding1D,
    VocabParallelLMHead1D,
)

from ..modeling.gemma2 import Gemma2PipelineForwards
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]


class Gemma2Policy(Policy):
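    """Base sharding policy for Gemma2.

    With tensor parallelism enabled, the q/k/v/o and MLP projections of every
    Gemma2DecoderLayer are replaced by column-/row-parallel linears, the token
    embedding becomes vocab-parallel, and the attention head counts are rescaled to
    their per-rank values; RMSNorm modules are swapped for ColossalAI's RMSNorm
    replacement in all cases.
    """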

    def config_sanity_check(self):
        pass

    def preprocess(self):
        # Record whether the input embedding and lm_head weights are tied; this decides
        # whether a padded embedding is needed in the non-tensor-parallel case.
        self.tie_weight = self.tie_weight_check()
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the sharding policy, keyed by the Hugging Face Gemma2 module classes to patch."""
        from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model

        policy = {}

        # Choose the embedding replacement: vocab-parallel when tensor parallelism is on,
        # otherwise a padded embedding only when input and output embeddings are tied.
        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = VocabParallelEmbedding1D
        else:
            if self.tie_weight:
                embedding_cls = PaddingEmbedding

        norm_cls = RMSNorm

        if self.shard_config.enable_tensor_parallelism:
            # Attention attributes are divided by the tensor-parallel size so that each
            # rank sees only its local share of heads (e.g. 16 query / 8 KV heads with
            # tp_size=4 leaves 4 query and 2 KV heads per rank).
            tp_size = self.shard_config.tensor_parallel_size
            num_q_heads = self.model.config.num_attention_heads // tp_size
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
                "self_attn.num_heads": num_q_heads,
            }
            num_kv_heads = self.model.config.num_key_value_heads // tp_size
            decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

            policy[Gemma2DecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    # MLP: gate/up projections are split column-wise, the down projection row-wise.
                    SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                    # Attention: q/k/v projections are column-parallel, the output projection row-parallel.
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

        # Replace the input embedding on Gemma2Model when a replacement class was selected above.
        if embedding_cls is not None:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=embedding_cls,
                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                ),
                policy=policy,
                target_key=Gemma2Model,
            )

        # Swap the four per-layer RMSNorm modules of each decoder layer.
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
                SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
                SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
            ],
            policy=policy,
            target_key=Gemma2DecoderLayer,
        )

        # Swap the final RMSNorm of the Gemma2Model backbone.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=norm_cls,
            ),
            policy=policy,
            target_key=Gemma2Model,
        )
        return policy

    def postprocess(self):
        return self.model


class Gemma2ForCausalLMPolicy(Gemma2Policy):
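    """Policy for Gemma2ForCausalLM.

    Adds the `lm_head` handling on top of Gemma2Policy: a vocab-parallel head under
    tensor parallelism (optionally keeping the output logits parallel via a forward
    swap), and a padded head otherwise.
    """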

    def module_policy(self):
        from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # Shard lm_head across the vocabulary dimension; gather its output only when
            # the caller does not keep the logits parallel.
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=VocabParallelLMHead1D,
                    kwargs=dict(
                        gather_output=not self.shard_config.parallel_output,
                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                    ),
                ),
                policy=policy,
                target_key=Gemma2ForCausalLM,
            )
            if self.shard_config.parallel_output:
                # When logits stay vocab-parallel, swap in the shard-aware forward from
                # Gemma2PipelineForwards.
                method_replacement = {
                    "forward": partial(
                        Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
                    )
                }
                self.append_or_create_method_replacement(
                    description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
                )
        else:
            # Without tensor parallelism, only pad the vocabulary dimension of lm_head.
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=PaddingLMHead,
                    kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
                ),
                policy=policy,
                target_key=Gemma2ForCausalLM,
            )

        return policy
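
# Illustrative usage only (a rough sketch, not part of this module; the exact
# ShardFormer API and ShardConfig arguments should be checked against the installed
# ColossalAI version):
#
#     from colossalai.shardformer import ShardConfig, ShardFormer
#     from transformers import Gemma2ForCausalLM
#
#     model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
#     shard_config = ShardConfig(tensor_parallel_process_group=tp_group)  # tp_group: your TP process group
#     sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(
#         model, policy=Gemma2ForCausalLMPolicy()
#     )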