From f232ff33afe5aea518088fc568d8cf0f32b5c0de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BD=95=E9=87=8F?= Date: Thu, 18 May 2023 15:44:48 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=E6=8E=A7=E5=88=B6?= =?UTF-8?q?=E5=8F=B0=E9=97=AA=E7=83=81=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli_demo.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/cli_demo.py b/cli_demo.py index 3559840..f92d3d8 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -29,6 +29,7 @@ def signal_handler(signal, frame): def main(): history = [] global stop_stream + signal.signal(signal.SIGINT, signal_handler) print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") while True: query = input("\n用户:") @@ -39,20 +40,16 @@ def main(): os.system(clear_command) print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") continue - count = 0 + prev_response = "" + print("ChatGLM-6B:", end="", flush=True) for response, history in model.stream_chat(tokenizer, query, history=history): if stop_stream: stop_stream = False break else: - count += 1 - if count % 8 == 0: - os.system(clear_command) - print(build_prompt(history), flush=True) - signal.signal(signal.SIGINT, signal_handler) - os.system(clear_command) - print(build_prompt(history), flush=True) - + print(response[len(prev_response):], end="", flush=True) + prev_response = response + print("\n", end="", flush=True) if __name__ == "__main__": main()