diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm2.py similarity index 100% rename from colossalai/shardformer/modeling/chatglm.py rename to colossalai/shardformer/modeling/chatglm2.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index eec339c02..2fe49f0d5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -125,9 +125,9 @@ _POLICY_LIST = { # ChatGLM "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"), "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": - PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm2.py similarity index 98% rename from colossalai/shardformer/policies/chatglm.py rename to colossalai/shardformer/policies/chatglm2.py index e6b458936..a15aa856d 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -7,7 +7,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -15,7 +15,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( GLMBlock, ) -from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 823ca032f..2a492361b 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,7 +2,7 @@ from .albert import * from .bert import * from .blip2 import * from .bloom import * -from .chatglm import * +from .chatglm2 import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm2.py similarity index 100% rename from tests/kit/model_zoo/transformers/chatglm.py rename to tests/kit/model_zoo/transformers/chatglm2.py diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py similarity index 100% rename from tests/test_shardformer/test_model/test_shard_chatglm.py rename to tests/test_shardformer/test_model/test_shard_chatglm2.py