[shardformer] polish chatglm code

pull/4445/head
klhhhhh 1 year ago committed by Hongxin Liu
parent 8620009dd7
commit 1a29e8fc29

@ -116,6 +116,9 @@ _POLICY_LIST = {
# Sam
"transformers.models.sam.modeling_sam.SamModel":
PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
# ChatGLM
"tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel":
PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
}

@ -1,7 +1,6 @@
from typing import Dict, Union
import torch.nn as nn
from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
import colossalai.shardformer.layer as col_nn
@ -9,49 +8,6 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
class ChatGLMModelPolicy(Policy):
    """Shardformer policy describing how ``GLMBlock`` modules are handled.

    The policy pads the vocabulary in ``preprocess`` and, when tensor
    parallelism is enabled, registers a ``ModulePolicyDescription`` for
    ``GLMBlock`` in ``module_policy``.
    """

    def config_sanity_check(self):
        """Perform no checks; this policy currently validates nothing."""
        pass

    def preprocess(self):
        """Resize the token embeddings to a multiple of the tensor-parallel size.

        Returns:
            ``self.model``, after ``resize_token_embeddings`` has been called
            on it if ``config.vocab_size`` was not already divisible by
            ``shard_config.tensor_parallel_size``.
        """
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        remainder = vocab_size % world_size
        if remainder != 0:
            # Round the vocabulary size up to the next multiple of world_size.
            new_vocab_size = vocab_size + world_size - remainder
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the mapping from module classes to their replacement policy.

        Returns:
            A dict keyed by module class. It contains an entry for
            ``GLMBlock`` only when ``shard_config.enable_tensor_parallelism``
            is truthy; otherwise it is empty.
        """
        # Function-scope import of the model class used as the policy key.
        from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    # NOTE: the replacements below are disabled, so the
                    # description currently replaces nothing.
                    # SubModuleReplacementDescription(
                    #     suffix = "self_attention.query_key_value",
                    #     target_module = col_nn.Linear1D_Col,
                    # ),
                    # SubModuleReplacementDescription(
                    #     suffix = "self_attention.dense",
                    #     target_module = col_nn.Linear1D_Row,
                    # )
                    # SubModuleReplacementDescription(
                    #     suffix = "self_attention.core_attention.attention_dropout",
                    #     target_module = col_nn.DropoutForParallelInput,
                    # )
                ],
            )

        # The original fell off the end here and implicitly returned None,
        # contradicting the declared Dict return type.
        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model
class ChatGLMModelPolicy(Policy):

@ -19,6 +19,7 @@ from colossalai.testing import (
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,

Loading…
Cancel
Save