From da09ca4dff6a98cfc254831283e5110317e01352 Mon Sep 17 00:00:00 2001
From: "Guo Y.K"
Date: Thu, 23 Mar 2023 15:59:38 +0800
Subject: [PATCH] rewrite web_demo.py with gradio.Chatbot and support loading
 the model from a local directory

---
 README.md    |  2 ++
 README_en.md |  2 ++
 cli_demo.py  |  6 ++--
 web_demo.py  | 94 +++++++++++++++++++++++++++++++++++-----------------
 4 files changed, 71 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index 734ce70..262ba07 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,8 @@ ChatGLM-6B 使用了和 ChatGPT 相似的技术，针对中文问答和对话进
 ```
 完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b) 上查看。如果你从 Hugging Face Hub 上下载checkpoint的速度较慢，也可以从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载。
 
+手动下载的模型放置在 `model` 目录下
+
 ### Demo
 
 我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo 和一个命令行 Demo。使用时首先需要下载本仓库：
diff --git a/README_en.md b/README_en.md
index b4dcfe8..721d9e4 100644
--- a/README_en.md
+++ b/README_en.md
@@ -51,6 +51,8 @@ Generate dialogue with the following code
 ```
 The full model implementation is on [HuggingFace Hub](https://huggingface.co/THUDM/chatglm-6b).
 
+Or you can download the model manually and put it in the `model` directory.
+
 ### Demo
 
 We provide a Web demo based on [Gradio](https://gradio.app) and a command line demo in the repo. First clone our repo with:
diff --git a/cli_demo.py b/cli_demo.py
index 8a043fb..e093c24 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -2,8 +2,10 @@ import os
 import platform
 from transformers import AutoTokenizer, AutoModel
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+MODEL_ID = "./model" if os.path.exists("./model") else "THUDM/chatglm-6b"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda()
 model = model.eval()
 
 os_name = platform.system()
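Note: both demos resolve the checkpoint source with the same one-liner. A minimal standalone sketch of that resolution logic, factored into a helper; the name resolve_model_id is hypothetical, used here only for illustration (from_pretrained accepts either a local path or a Hub model id):

import os

def resolve_model_id(local_dir: str = "./model", hub_id: str = "THUDM/chatglm-6b") -> str:
    # Prefer a manually downloaded checkpoint in local_dir; otherwise
    # fall back to fetching the checkpoint from the Hugging Face Hub.
    return local_dir if os.path.exists(local_dir) else hub_id

# Usage, mirroring cli_demo.py above and web_demo.py below:
# tokenizer = AutoTokenizer.from_pretrained(resolve_model_id(), trust_remote_code=True)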
diff --git a/web_demo.py b/web_demo.py
index 88a6dc8..969fdc4 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,45 +1,77 @@
+import os
 from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+MODEL_ID = "./model" if os.path.exists("./model") else "THUDM/chatglm-6b"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda()
 model = model.eval()
 
 MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
+
+WELCOME_PROMPT = [[None, "[ChatGLM-6B]: Welcome, please input text and press enter"]]
 
 
-def predict(input, max_length, top_p, temperature, history=None):
-    if history is None:
-        history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
-        updates = []
+def predict(input, max_length, top_p, temperature, history):
+    for _, history in model.stream_chat(
+            tokenizer, input, history,
+            max_length=max_length,
+            top_p=top_p,
+            temperature=temperature,
+    ):
+        chatbot = []
+
         for query, response in history:
-            updates.append(gr.update(visible=True, value="用户:" + query))
-            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
-        if len(updates) < MAX_BOXES:
-            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-        yield [history] + updates
+            chatbot.append([
+                "[用户]: " + query,
+                "[ChatGLM-6B]: " + response
+            ])
+
+        if len(chatbot) > MAX_TURNS:
+            chatbot = chatbot[-MAX_TURNS:]
+
+        yield history, WELCOME_PROMPT + chatbot
 
 
-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+with gr.Blocks(title="ChatGLM-6B", css="#main-chatbot { height: 480px; }") as demo:
+    input_cache = gr.State()
+    history = gr.State([])
 
     with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-        button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+        with gr.Column():
+            pass
+        with gr.Column():
+            chatbot = gr.Chatbot(
+                show_label=False,
+                elem_id="main-chatbot"
+            )
+            input = gr.Textbox(
+                show_label=False,
+                placeholder="Input text and press enter",
+                interactive=True,
+            )
+            with gr.Box():
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+        with gr.Column():
+            pass
+
+    input.submit(
+        lambda x: ("", x),
+        [input],
+        [input, input_cache]
+    ).then(
+        predict,
+        [input_cache, max_length, top_p, temperature, history],
+        [history, chatbot],
+    )
+
+    demo.load(
+        lambda: WELCOME_PROMPT,
+        None,
+        [chatbot]
+    )
+
 demo.queue().launch(share=False, inbrowser=True)
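Note: the UI wiring above relies on gradio's event chaining (Textbox.submit(...).then(...), available since gradio 3.22, which this patch evidently targets). The first step clears the textbox and stashes its value in a gr.State so the input box is usable again immediately; the second step runs the slow generation against the cached value and redraws the chat. A minimal standalone sketch of that pattern follows; the names respond and msg are stand-ins, and a plain echo function replaces model.stream_chat:

import gradio as gr

WELCOME_PROMPT = [[None, "[Bot]: Welcome, please input text and press enter"]]


def respond(input_text, history):
    # history mirrors the list of [query, response] pairs that
    # model.stream_chat maintains in the real demo; Chatbot renders
    # each pair as one user/bot exchange.
    history = history + [[input_text, "Echo: " + input_text]]
    return history, WELCOME_PROMPT + history


with gr.Blocks() as demo:
    input_cache = gr.State()
    history = gr.State([])
    chatbot = gr.Chatbot(show_label=False)
    msg = gr.Textbox(show_label=False, placeholder="Input text and press enter")

    # Step 1: clear the textbox right away and cache the submitted value.
    # Step 2: generate a reply from the cached value and update the chat.
    msg.submit(
        lambda x: ("", x),
        [msg],
        [msg, input_cache],
    ).then(
        respond,
        [input_cache, history],
        [history, chatbot],
    )

demo.queue().launch()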