diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py
index 934b99b83..46aa3b52a 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm.py
@@ -90,7 +90,31 @@ class ChatGLMModelPolicy(Policy):
                                                                 policy=policy,
                                                                 target_key=ChatGLMModel)
 
+            else:
+                self.append_or_create_submodule_replacement(description=[
+                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
+                    SubModuleReplacementDescription(suffix="post_attention_layernorm",
+                                                    target_module=col_nn.FusedRMSNorm)
+                ],
+                                                            policy=policy,
+                                                            target_key=GLMBlock)
+
+                if self.model.config.post_layer_norm:
+                    self.append_or_create_submodule_replacement(description=[
+                        SubModuleReplacementDescription(suffix="encoder.final_layernorm",
+                                                        target_module=col_nn.FusedRMSNorm)
+                    ],
+                                                                policy=policy,
+                                                                target_key=ChatGLMModel)
+
         return policy
 
     def postprocess(self):
         return self.model
+
+
+class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
+
+    def module_policy(self):
+        policy = super().module_policy()
+        return policy
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 96f27de2a..1feb11ffc 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -23,7 +23,7 @@ class ViTPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
+        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
 
         policy = {}
 
diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py
index 1408babed..04e73a832 100644
--- a/tests/kit/model_zoo/transformers/chatglm.py
+++ b/tests/kit/model_zoo/transformers/chatglm.py
@@ -3,7 +3,7 @@ import transformers
 
 from ..registry import ModelAttribute, model_zoo
 from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from .chatglm2_6b.modeling_chatglm import ChatGLMModel
+from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 # ================================
 # Register single-sentence ChatGLM
@@ -21,7 +21,7 @@ output_transform_fn = lambda x: x
 
 # define loss function
 loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean()
-loss_fn = lambda x: x.loss
+loss_fn = lambda x: x.logits.mean()
 
 config = ChatGLMConfig(num_layers=1,
                        padded_vocab_size=65024,
                        hidden_size=64,
@@ -36,3 +36,10 @@ model_zoo.register(name='transformers_chatglm',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_for_chatglm_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name="transformers_chatglm_for_conditional_generation",
+                   model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+                   data_gen_fn=data_gen,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index 2cdf5da2e..a0fa4bd82 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -7,7 +7,7 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
+from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
@@ -85,6 +85,8 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
     shard_former = ShardFormer(shard_config=shard_config)
     if name == "transformers_chatglm":
         sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
+    else:
+        sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
 
     check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
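
Usage note (not part of the patch): the sketch below mirrors how the updated run_chatglm_test picks a policy per model-zoo entry, so both registered ChatGLM variants receive the FusedRMSNorm replacements added above. The helper name shard_chatglm and its standalone-function shape are illustrative assumptions; the ShardConfig flags come from the test's parameters, and running it is assumed to happen inside a colossalai-launched distributed context, as in the test.

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import (
    ChatGLMForConditionalGenerationPolicy,
    ChatGLMModelPolicy,
)


def shard_chatglm(name, model, enable_fused_normalization, enable_tensor_parallelism):
    # Hypothetical helper: dispatch on the model-zoo entry name, as the test does.
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
    shard_former = ShardFormer(shard_config=shard_config)
    if name == "transformers_chatglm":
        policy = ChatGLMModelPolicy()
    else:    # "transformers_chatglm_for_conditional_generation"
        policy = ChatGLMForConditionalGenerationPolicy()
    # ChatGLMForConditionalGenerationPolicy inherits module_policy() from
    # ChatGLMModelPolicy unchanged, so both branches apply the same replacements.
    return shard_former.optimize(model, policy).cuda()

The loss_fn for the new model-zoo entry is switched to x.logits.mean(), presumably because the registered data_gen provides no labels and the model would therefore return no loss; taking the mean of the logits still gives check_forward_backward a scalar to backpropagate.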