From bf39dac0670d4cbb734aaa4664001f1863cecf14 Mon Sep 17 00:00:00 2001 From: holk-h Date: Fri, 24 Mar 2023 18:34:09 +0800 Subject: [PATCH] Support stream out interruption by using Ctrl+C --- cli_demo.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/cli_demo.py b/cli_demo.py index 8a043fb..fea47fc 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -1,14 +1,15 @@ import os import platform +import signal from transformers import AutoTokenizer, AutoModel -tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +tokenizer = AutoTokenizer.from_pretrained("./model", trust_remote_code=True) +model = AutoModel.from_pretrained("./model", trust_remote_code=True).half().cuda() model = model.eval() os_name = platform.system() clear_command = 'cls' if os_name == 'Windows' else 'clear' - +stop_stream = False def build_prompt(history): prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" @@ -17,9 +18,13 @@ def build_prompt(history): prompt += f"\n\nChatGLM-6B:{response}" return prompt +def signal_handler(signal, frame): + global stop_stream + stop_stream = True def main(): history = [] + global stop_stream print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") while True: query = input("\n用户:") @@ -32,10 +37,15 @@ def main(): continue count = 0 for response, history in model.stream_chat(tokenizer, query, history=history): - count += 1 - if count % 8 == 0: - os.system(clear_command) - print(build_prompt(history), flush=True) + if stop_stream: + stop_stream = False + break + else: + count += 1 + if count % 8 == 0: + os.system(clear_command) + print(build_prompt(history), flush=True) + signal.signal(signal.SIGINT,signal_handler) os.system(clear_command) print(build_prompt(history), flush=True)