diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index f443553bb..9a7cf34c1 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -33,22 +33,6 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( ) loss_fn = lambda x: x["loss"] -config = AutoConfig.from_pretrained( - "THUDM/chatglm2-6b", - trust_remote_code=True, - num_layers=2, - padded_vocab_size=65024, - hidden_size=64, - ffn_hidden_size=214, - num_attention_heads=8, - kv_channels=16, - rmsnorm=True, - original_rope=True, - use_cache=True, - multi_query_attention=False, - torch_dtype=torch.float32, -) - infer_config = AutoConfig.from_pretrained( "THUDM/chatglm2-6b", @@ -68,6 +52,21 @@ infer_config = AutoConfig.from_pretrained( def init_chatglm(): + config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, + num_layers=2, + padded_vocab_size=65024, + hidden_size=64, + ffn_hidden_size=214, + num_attention_heads=8, + kv_channels=16, + rmsnorm=True, + original_rope=True, + use_cache=True, + multi_query_attention=False, + torch_dtype=torch.float32, + ) model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True) for m in model.modules(): if m.__class__.__name__ == "RMSNorm":