diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 90347a984..e38363040 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -116,6 +116,9 @@ _POLICY_LIST = {
     # Sam
     "transformers.models.sam.modeling_sam.SamModel":
         PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
+    # ChatGLM
+    "tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel":
+        PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
 }
diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py
index c17b92c8d..934b99b83 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm.py
@@ -1,7 +1,6 @@
 from typing import Dict, Union
 
 import torch.nn as nn
-from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
 
 import colossalai.shardformer.layer as col_nn
 
@@ -9,49 +8,6 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
 
 __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
 
-class ChatGLMModelPolicy(Policy):
-
-    def config_sanity_check(self):
-        pass
-
-    def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
-
-        return self.model
-
-    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
-
-        policy = {}
-
-        if self.shard_config.enable_tensor_parallelism:
-
-            policy[GLMBlock] = ModulePolicyDescription(
-                attribute_replacement = {},
-                sub_module_replacement = [
-                    # SubModuleReplacementDescription(
-                    #     suffix = "self_attention.query_key_value",
-                    #     target_module = col_nn.Linear1D_Col,
-                    # ),
-                    # SubModuleReplacementDescription(
-                    #     suffix = "self_attention.dense",
-                    #     target_module = col_nn.Linear1D_Row,
-                    # )
-                    # SubModuleReplacementDescription(
-                    #     suffix = "self_attention.core_attention.attention_dropout",
-                    #     target_module = col_nn.DropoutForParallelInput,
-                    # )
-                ],)
-
-
-    def postprocess(self):
-        return self.model
 
 
 class ChatGLMModelPolicy(Policy):
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index f05649fcb..2cdf5da2e 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -19,6 +19,7 @@ from colossalai.testing import (
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
+
 
 def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
     # check forward
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,