diff --git a/cli_demo.py b/cli_demo.py
index 8a043fb..ca9f893 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -10,34 +10,34 @@
 os_name = platform.system()
 clear_command = 'cls' if os_name == 'Windows' else 'clear'
 
-def build_prompt(history):
-    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
-    for query, response in history:
-        prompt += f"\n\n用户:{query}"
-        prompt += f"\n\nChatGLM-6B:{response}"
-    return prompt
-
+def build_prompt(history, prev_resp, count):
+    cur_resp = history[count][1][len(prev_resp[0]):]
+    d = cur_resp.encode('unicode_escape')
+    if b'\\ufff' in d:
+        return
+    print(cur_resp, end='', flush=True)
+    prev_resp[0] += cur_resp
 
 def main():
     history = []
+    os.system(clear_command)
     print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+    count = 0
     while True:
-        query = input("\n用户:")
+        query = input("\n\n用户:")
         if query == "stop":
             break
         if query == "clear":
             history = []
+            count = 0
             os.system(clear_command)
             print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
             continue
-        count = 0
+        print('\nChat:', end='')
+        prev_resp = [""]
         for response, history in model.stream_chat(tokenizer, query, history=history):
-            count += 1
-            if count % 8 == 0:
-                os.system(clear_command)
-                print(build_prompt(history), flush=True)
-        os.system(clear_command)
-        print(build_prompt(history), flush=True)
+        count += 1
 
 
 if __name__ == "__main__":
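
Note on the new streaming logic (a sketch for reviewers, not part of the patch): model.stream_chat re-yields the full response so far on every step, so the reworked build_prompt slices off the already-printed prefix held in prev_resp[0] and prints only the new delta, with count indexing the current conversation turn in history. The unicode_escape check guards against half-decoded output: while a multi-byte character is presumably still being generated, the detokenizer emits the replacement character U+FFFD, which unicode_escape renders as the bytes b'\ufffd', so any chunk containing b'\ufff' is held back until it decodes cleanly. A minimal standalone illustration, assuming that U+FFFD behavior; the helper name is_partial is hypothetical:

    # Assumes the detokenizer yields U+FFFD while a multi-byte
    # character is still incomplete, as the patch above guards against.
    def is_partial(chunk: str) -> bool:
        # 'unicode_escape' renders U+FFFD as the bytes b'\\ufffd', so
        # b'\\ufff' appears whenever a replacement character (or another
        # U+FFF0..U+FFFF code point) is present in the chunk.
        return b'\\ufff' in chunk.encode('unicode_escape')

    print(is_partial('\ufffd'))  # True  -- incomplete character, hold this chunk
    print(is_partial('你好'))    # False -- fully decoded text, safe to print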
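
One design note: printing deltas in place avoids the old approach of clearing the screen and reprinting the whole transcript every eight chunks, which flickered and scrolled badly in most terminals.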