mirror of https://github.com/hpcaitech/ColossalAI
97 lines
5.1 KiB
Python
97 lines
5.1 KiB
Python
from typing import Dict, Union
|
|
|
|
import torch.nn as nn
|
|
|
|
import colossalai.shardformer.layer as col_nn
|
|
|
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
# Names exported when this module is star-imported.
# NOTE(review): ChatGLMForConditionalGenerationPolicy is not defined in the visible
# part of this file — confirm it exists elsewhere in the module, otherwise
# `from <module> import *` raises AttributeError.
__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
|
|
|
|
|
|
class ChatGLMModelPolicy(Policy):
    """Policy describing which parts of ``ChatGLMModel`` / ``GLMBlock`` get replaced.

    ``module_policy`` builds a mapping from module class to a
    ``ModulePolicyDescription``; ``preprocess`` and ``postprocess`` both hand
    back ``self.model``.
    """

    def config_sanity_check(self):
        """Perform no checks; this policy currently validates nothing."""

    def preprocess(self):
        """Pad the token embedding up to a multiple of the tensor-parallel size.

        Reads ``padded_vocab_size`` from the model config and, when it is not
        divisible by ``shard_config.tensor_parallel_size``, calls
        ``resize_token_embeddings`` with the next multiple.

        Returns:
            ``self.model``, after any resize.
        """
        padded_vocab = self.model.config.padded_vocab_size
        tp_size = self.shard_config.tensor_parallel_size

        remainder = padded_vocab % tp_size
        if remainder != 0:
            # Round up to the next multiple of tp_size.
            self.model.resize_token_embeddings(padded_vocab + tp_size - remainder)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the per-module replacement descriptions for this model.

        Entries are only added for the features switched on in
        ``self.shard_config`` (``enable_tensor_parallelism`` and
        ``enable_fused_normalization``).

        Returns:
            A dict keyed by ``ChatGLMModel`` and/or ``GLMBlock`` whose values
            are ``ModulePolicyDescription`` objects; empty when neither
            feature is enabled.
        """
        from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            config = self.model.config
            tp_size = self.shard_config.tensor_parallel_size

            num_heads = config.num_attention_heads
            full_projection = config.kv_channels * num_heads

            heads_per_rank = num_heads // tp_size
            projection_per_rank = full_projection // tp_size
            # Kept as its own floor division: (3 * x) // tp is not always 3 * (x // tp).
            qkv_per_rank = (full_projection * 3) // tp_size

            policy[ChatGLMModel] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="embedding.word_embeddings",
                        target_module=col_nn.VocabParallelEmbedding1D,
                    )
                ],
            )

            attention_attributes = {
                "self_attention.num_attention_heads_per_partition": heads_per_rank,
                "self_attention.projection_size": projection_per_rank,
                "self_attention.qkv_hidden_size": qkv_per_rank,
                "self_attention.core_attention.num_attention_heads_per_partition": heads_per_rank,
                "self_attention.core_attention.hidden_size_per_partition": projection_per_rank,
            }
            attention_targets = (
                ("self_attention.query_key_value", col_nn.Linear1D_Col),
                ("self_attention.dense", col_nn.Linear1D_Row),
                ("self_attention.core_attention.attention_dropout", col_nn.DropoutForParallelInput),
            )
            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement=attention_attributes,
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix=suffix, target_module=target)
                    for suffix, target in attention_targets
                ],
            )

        # optimization configuration
        # The FusedLayerNorm swap is only applied when the config does not use RMSNorm.
        if self.shard_config.enable_fused_normalization and not self.model.config.rmsnorm:
            block_norms = [
                SubModuleReplacementDescription(suffix=suffix, target_module=col_nn.FusedLayerNorm)
                for suffix in ("input_layernorm", "post_attention_layernorm")
            ]
            self.append_or_create_submodule_replacement(
                description=block_norms,
                policy=policy,
                target_key=GLMBlock,
            )

            # NOTE(review): treated as nested under the non-RMSNorm branch;
            # confirm against the upstream file, whose indentation was lost here.
            if self.model.config.post_layer_norm:
                final_norm = [
                    SubModuleReplacementDescription(
                        suffix="encoder.final_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ]
                self.append_or_create_submodule_replacement(
                    description=final_norm,
                    policy=policy,
                    target_key=ChatGLMModel,
                )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model
|