From 2ed89f3898c3fd4656ce7bbf744f4e8cd087e128 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Sun, 19 Mar 2023 14:33:05 +0800
Subject: [PATCH] Add support for streaming output

---
 cli_demo.py | 48 ++++++++++++++++++++++++++++++++++--------------
 web_demo.py | 18 +++++++++---------
 2 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/cli_demo.py b/cli_demo.py
index d87f707..768df90 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -7,18 +7,38 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).ha
 model = model.eval()
 
 os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
 
-history = []
-print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
-while True:
-    query = input("\n用户:")
-    if query == "stop":
-        break
-    if query == "clear":
-        history = []
-        command = 'cls' if os_name == 'Windows' else 'clear'
-        os.system(command)
-        print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
-        continue
-    response, history = model.chat(tokenizer, query, history=history)
-    print(f"ChatGLM-6B:{response}")
+
+def build_prompt(history):
+    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+    for query, response in history:
+        prompt += f"\n用户:{query}"
+        prompt += f"\nChatGLM-6B:{response}"
+    return prompt
+
+
+def main():
+    history = []
+    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+    while True:
+        query = input("\n用户:")
+        if query == "stop":
+            break
+        if query == "clear":
+            history = []
+            os.system(clear_command)
+            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+            continue
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            count += 1
+            if count % 8 == 0:
+                os.system(clear_command)
+                print(build_prompt(history), flush=True)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/web_demo.py b/web_demo.py
index 03cb319..39709a7 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -12,15 +12,15 @@ MAX_BOXES = MAX_TURNS * 2
 def predict(input, max_length, top_p, temperature, history=None):
     if history is None:
         history = []
-    response, history = model.chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                   temperature=temperature)
-    updates = []
-    for query, response in history:
-        updates.append(gr.update(visible=True, value="用户:" + query))
-        updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
-    if len(updates) < MAX_BOXES:
-        updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-    return [history] + updates
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        updates = []
+        for query, response in history:
+            updates.append(gr.update(visible=True, value="用户:" + query))
+            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
+        if len(updates) < MAX_BOXES:
+            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
+        yield [history] + updates
 
 
 with gr.Blocks() as demo:
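
Note (not part of the patch): a minimal sketch of how the stream_chat interface used above can be consumed on its own, outside the demo scripts. It assumes the THUDM/chatglm-6b weights are reachable and a CUDA GPU is available, mirroring the setup at the top of cli_demo.py.

    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model = model.eval()

    # stream_chat is a generator: each iteration yields the partial response so far
    # together with the updated history, so the caller can refresh its display
    # incrementally instead of waiting for the full reply as model.chat does.
    history = []
    for response, history in model.stream_chat(tokenizer, "Hello", history=history):
        print(response, flush=True)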