From 343e7bc7b6126718f10dc57ef3c911958c4b273b Mon Sep 17 00:00:00 2001 From: duzx16 Date: Tue, 28 Mar 2023 19:52:32 +0800 Subject: [PATCH] Fix model path --- cli_demo.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cli_demo.py b/cli_demo.py index fea47fc..1c3ff2b 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -3,14 +3,15 @@ import platform import signal from transformers import AutoTokenizer, AutoModel -tokenizer = AutoTokenizer.from_pretrained("./model", trust_remote_code=True) -model = AutoModel.from_pretrained("./model", trust_remote_code=True).half().cuda() +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval() os_name = platform.system() clear_command = 'cls' if os_name == 'Windows' else 'clear' stop_stream = False + def build_prompt(history): prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" for query, response in history: @@ -18,10 +19,12 @@ def build_prompt(history): prompt += f"\n\nChatGLM-6B:{response}" return prompt + def signal_handler(signal, frame): global stop_stream stop_stream = True + def main(): history = [] global stop_stream @@ -45,7 +48,7 @@ def main(): if count % 8 == 0: os.system(clear_command) print(build_prompt(history), flush=True) - signal.signal(signal.SIGINT,signal_handler) + signal.signal(signal.SIGINT, signal_handler) os.system(clear_command) print(build_prompt(history), flush=True)