diff --git a/cli_demo_gpus.py b/cli_demo_gpus.py
new file mode 100644
index 0000000..56c8d83
--- /dev/null
+++ b/cli_demo_gpus.py
@@ -0,0 +1,73 @@
+import os
+import platform
+import signal
+from transformers import AutoTokenizer, AutoModel
+from utils import load_model_on_gpus
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
+# Set by the SIGINT handler; polled by main() while a reply is streaming.
+stop_stream = False
+# Banner printed at start-up, after "clear", and at the top of each redraw.
+WELCOME_MESSAGE = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+
+
+def build_prompt(history):
+    """Return the text to redraw: the banner followed by every past turn.
+
+    history is an iterable of (query, response) pairs.
+    """
+    parts = [WELCOME_MESSAGE]
+    for query, response in history:
+        parts.append(f"用户:{query}")
+        parts.append(f"ChatGLM-6B:{response}")
+    return "\n\n".join(parts)
+
+
+def signal_handler(signum, frame):
+    """SIGINT handler: flag the reply currently being streamed for cancellation."""
+    global stop_stream
+    stop_stream = True
+
+
+def main():
+    """Run the interactive chat loop until the user enters "stop"."""
+    global stop_stream
+    history = []
+    # Register the handler once, before the first reply, so Ctrl-C can cancel
+    # generation from the very first chunk onwards.
+    signal.signal(signal.SIGINT, signal_handler)
+    print(WELCOME_MESSAGE)
+    while True:
+        query = input("\n用户:")
+        command = query.strip()
+        if command == "stop":
+            break
+        if command == "clear":
+            history = []
+            os.system(clear_command)
+            print(WELCOME_MESSAGE)
+            continue
+        # Discard a Ctrl-C that arrived while idle so it cannot cancel this reply.
+        stop_stream = False
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            if stop_stream:
+                stop_stream = False
+                break
+            count += 1
+            # Redraw the transcript on every 8th chunk instead of on each one.
+            if count % 8 == 0:
+                os.system(clear_command)
+                print(build_prompt(history), flush=True)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/web_demo.py b/web_demo.py
index 97ea622..8f221bf 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,9 +1,12 @@
-from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 import mdtex2html
+from transformers import AutoModel, AutoTokenizer
+from utils import 
load_model_on_gpus + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2) model = model.eval() """Override Chatbot.postprocess""" @@ -60,7 +63,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history): chatbot.append((parse_text(input), "")) for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, temperature=temperature): - chatbot[-1] = (parse_text(input), parse_text(response)) + chatbot[-1] = (parse_text(input), parse_text(response)) yield chatbot, history @@ -74,21 +77,24 @@ def reset_state(): with gr.Blocks() as demo: - gr.HTML("""

ChatGLM

""") + gr.HTML("""

CodeLab

""") chatbot = gr.Chatbot() with gr.Row(): with gr.Column(scale=4): with gr.Column(scale=12): - user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( + user_input = gr.Textbox(show_label=False, placeholder="输入聊天内容", lines=10).style( container=False) with gr.Column(min_width=32, scale=1): - submitBtn = gr.Button("Submit", variant="primary") + submitBtn = gr.Button("发送", variant="primary") with gr.Column(scale=1): - emptyBtn = gr.Button("Clear History") - max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) - top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) - temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) + emptyBtn = gr.Button("清除历史记录") + max_length = gr.Slider( + 0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, + label="Top P", interactive=True) + temperature = gr.Slider( + 0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) history = gr.State([]) @@ -98,4 +104,4 @@ with gr.Blocks() as demo: emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) -demo.queue().launch(share=False, inbrowser=True) +demo.queue().launch(share=False, inbrowser=False, server_port=11001)