diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py
index 934b99b83..c17b92c8d 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm.py
@@ -1,6 +1,7 @@
 from typing import Dict, Union
 
 import torch.nn as nn
+from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
 
 import colossalai.shardformer.layer as col_nn
 
@@ -8,6 +9,50 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
 
+class ChatGLMModelPolicy(Policy):
+
+    def config_sanity_check(self):
+        pass
+
+    def preprocess(self):
+        # Resize the embedding so the vocabulary is divisible by the TP world size.
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+
+        return self.model
+
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
+
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[GLMBlock] = ModulePolicyDescription(
+                attribute_replacement={},
+                sub_module_replacement=[
+                    # SubModuleReplacementDescription(
+                    #     suffix="self_attention.query_key_value",
+                    #     target_module=col_nn.Linear1D_Col,
+                    # ),
+                    # SubModuleReplacementDescription(
+                    #     suffix="self_attention.dense",
+                    #     target_module=col_nn.Linear1D_Row,
+                    # ),
+                    # SubModuleReplacementDescription(
+                    #     suffix="self_attention.core_attention.attention_dropout",
+                    #     target_module=col_nn.DropoutForParallelInput,
+                    # ),
+                ],
+            )
+
+        return policy
+
+    def postprocess(self):
+        return self.model
 
 
 class ChatGLMModelPolicy(Policy):
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index 2cdf5da2e..f05649fcb 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -19,7 +19,6 @@ from colossalai.testing import (
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
 
-
 def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
     # check forward
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
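
A few reviewer-style notes on the WIP above; none of the following blocks are part of the diff.

First, the padding arithmetic in preprocess(): vocab_size + world_size - vocab_size % world_size rounds the vocabulary up to the next multiple of the tensor-parallel world size, so the token embedding can later be split evenly across ranks. A standalone check, with hypothetical numbers chosen only to exercise both branches:

def pad_vocab(vocab_size: int, world_size: int) -> int:
    # Round vocab_size up to the next multiple of world_size.
    if vocab_size % world_size != 0:
        return vocab_size + world_size - vocab_size % world_size
    return vocab_size

assert pad_vocab(65025, 4) == 65028    # 65028 / 4 = 16257 embedding rows per rank
assert pad_vocab(65024, 4) == 65024    # already divisible: unchanged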
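
Second, a minimal sketch of where module_policy appears to be heading once the commented-out replacements are enabled. This is a drop-in body for ChatGLMModelPolicy.module_policy, assuming the imports already at the top of chatglm.py; the suffixes are taken verbatim from the commented code. One caution: ChatGLM2 uses multi-query attention, so the column-wise split of the fused query_key_value weight should be verified against the packed K/V head layout, which may be exactly why these entries are still commented out.

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
    from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import GLMBlock

    policy = {}
    if self.shard_config.enable_tensor_parallelism:
        policy[GLMBlock] = ModulePolicyDescription(
            attribute_replacement={},
            sub_module_replacement=[
                # Fused QKV projection: shard along the output dimension.
                SubModuleReplacementDescription(
                    suffix="self_attention.query_key_value",
                    target_module=col_nn.Linear1D_Col,
                ),
                # Attention output projection: shard along the input dimension.
                SubModuleReplacementDescription(
                    suffix="self_attention.dense",
                    target_module=col_nn.Linear1D_Row,
                ),
                # Dropout applied to rank-local activations.
                SubModuleReplacementDescription(
                    suffix="self_attention.core_attention.attention_dropout",
                    target_module=col_nn.DropoutForParallelInput,
                ),
            ],
        )
    return policy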
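
Finally, on why the two linears get different target modules: Linear1D_Col shards a weight along its output dimension and Linear1D_Row along its input dimension, so pairing them lets each rank compute a partial result that a single all-reduce completes. A self-contained plain-PyTorch sketch of that algebra (the attention between the two matmuls is elided because it operates head-locally and does not mix ranks):

import torch

torch.manual_seed(0)
x = torch.randn(2, 8, dtype=torch.float64)        # (batch, hidden)
w_qkv = torch.randn(8, 12, dtype=torch.float64)   # stand-in for the fused QKV weight
w_out = torch.randn(12, 8, dtype=torch.float64)   # stand-in for the dense output weight

ranks = 2
qkv_shards = w_qkv.chunk(ranks, dim=1)   # Linear1D_Col: each rank keeps a slice of output columns
out_shards = w_out.chunk(ranks, dim=0)   # Linear1D_Row: each rank keeps the matching input rows

# Summing the per-rank partials stands in for the all-reduce in real tensor parallelism.
partials = [x @ qkv_shards[r] @ out_shards[r] for r in range(ranks)]
assert torch.allclose(sum(partials), x @ w_qkv @ w_out)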