ColossalAI/colossalai/shardformer/policies/chatglm.py
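
"""ShardFormer policy for the ChatGLM2 model.

Describes how ChatGLM2 submodules are rewritten for tensor parallelism
(vocab-parallel embedding, column/row-sharded attention projections) and
for optional fused LayerNorm kernels.
"""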

from typing import Dict, Union

import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']


class ChatGLMModelPolicy(Policy):

    def config_sanity_check(self):
        pass

    def preprocess(self):
        # Resize the token embedding so that the padded vocabulary size is divisible
        # by the tensor parallel world size.
        vocab_size = self.model.config.padded_vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            # Shard the word embedding along the vocabulary dimension.
            policy[ChatGLMModel] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="embedding.word_embeddings",
                        target_module=col_nn.VocabParallelEmbedding1D,
                    )
                ])
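
            # For each transformer block, split the attention heads and the QKV/projection
            # hidden sizes across tensor parallel ranks: column-parallel QKV projection,
            # row-parallel output projection, and a dropout that is aware of parallel input.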
            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.num_attention_heads_per_partition":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "self_attention.projection_size":
                        (self.model.config.kv_channels * self.model.config.num_attention_heads) //
                        self.shard_config.tensor_parallel_size,
                    "self_attention.qkv_hidden_size":
                        (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
                        self.shard_config.tensor_parallel_size,
                    "self_attention.core_attention.num_attention_heads_per_partition":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "self_attention.core_attention.hidden_size_per_partition":
                        self.model.config.kv_channels * self.model.config.num_attention_heads //
                        self.shard_config.tensor_parallel_size,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.core_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ])

        # optimization configuration: swap in fused LayerNorm kernels, which only applies
        # when the model uses standard LayerNorm (rmsnorm disabled).
        if self.shard_config.enable_fused_normalization:
            if not self.model.config.rmsnorm:
                self.append_or_create_submodule_replacement(
                    description=[
                        SubModuleReplacementDescription(suffix="input_layernorm",
                                                        target_module=col_nn.FusedLayerNorm),
                        SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                        target_module=col_nn.FusedLayerNorm)
                    ],
                    policy=policy,
                    target_key=GLMBlock)

                if self.model.config.post_layer_norm:
                    self.append_or_create_submodule_replacement(
                        description=[
                            SubModuleReplacementDescription(suffix="encoder.final_layernorm",
                                                            target_module=col_nn.FusedLayerNorm)
                        ],
                        policy=policy,
                        target_key=ChatGLMModel)

        return policy

    def postprocess(self):
        return self.model
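

# Illustrative usage: a minimal sketch, assuming the ShardConfig/ShardFormer API
# exported by colossalai.shardformer and a ChatGLM2 `model` instance built from the
# test-kit modeling file imported above.
#
#   from colossalai.shardformer import ShardConfig, ShardFormer
#
#   shard_config = ShardConfig(enable_tensor_parallelism=True,
#                              enable_fused_normalization=True)
#   sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(
#       model, policy=ChatGLMModelPolicy())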