rewrite web_demo.py with gradio.Chatbot and support load model from local directory

pull/202/head
Guo Y.K 2023-03-23 15:59:38 +08:00
parent 5513dd7d2c
commit da09ca4dff
4 changed files with 71 additions and 33 deletions

README.md

@@ -53,6 +53,8 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese Q&A and dialogue
 ```
 The full model implementation is available on the [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b). If downloading the checkpoint from the Hugging Face Hub is slow, you can also download it manually from [here](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/).
+
+Place a manually downloaded model in the `model` directory.
 
 ### Demo
 We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo. To use them, first clone this repository

README_en.md

@@ -51,6 +51,8 @@ Generate dialogue with the following code
 The full model implementation is on [HuggingFace Hub](https://huggingface.co/THUDM/chatglm-6b).
+
+Or you can download the model manually and put it in the `model` directory.
 
 ### Demo
 We provide a Web demo based on [Gradio](https://gradio.app) and a command line demo in the repo. First clone our repo with:
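
For reference, one way to populate that `model` directory is `snapshot_download` from the `huggingface_hub` package. This is a minimal sketch, not part of the commit, and assumes a `huggingface_hub` version that supports the `local_dir` argument:

```python
# Sketch (not part of this commit): fetch the checkpoint into ./model
# so both demos pick it up locally. Assumes `pip install huggingface_hub`.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="THUDM/chatglm-6b", local_dir="./model")
```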

cli_demo.py

@@ -2,8 +2,10 @@ import os
 import platform
 from transformers import AutoTokenizer, AutoModel
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+MODEL_ID = "./model" if os.path.exists('./model') else "THUDM/chatglm-6b"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda()
 model = model.eval()
 
 os_name = platform.system()
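
Note that `os.path.exists` only tests that the directory is present, not that it holds a complete checkpoint, so a partially downloaded `./model` would still be preferred and then fail inside `from_pretrained`. A slightly stricter variant could look like the following sketch (`pick_model_id` is an illustrative name, not from the PR):

```python
import os

def pick_model_id(local_dir="./model", hub_id="THUDM/chatglm-6b"):
    # config.json is the minimum a directory needs for from_pretrained
    # to treat it as a model checkpoint; fall back to the Hub otherwise.
    if os.path.isfile(os.path.join(local_dir, "config.json")):
        return local_dir
    return hub_id
```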

web_demo.py

@@ -1,45 +1,77 @@
+import os
+
 from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+MODEL_ID = "./model" if os.path.exists('./model') else "THUDM/chatglm-6b"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda()
 model = model.eval()
 
 MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
+WELCOME_PROMPT = [[None, "[ChatGLM-6B]Welcome, please input text and press enter"]]
 
 
-def predict(input, max_length, top_p, temperature, history=None):
-    if history is None:
-        history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
-        updates = []
+def predict(input, max_length, top_p, temperature, history):
+    for _, history in model.stream_chat(
+        tokenizer, input, history,
+        max_length=max_length,
+        top_p=top_p,
+        temperature=temperature,
+    ):
+        chatbot = []
         for query, response in history:
-            updates.append(gr.update(visible=True, value="用户:" + query))
-            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
-        if len(updates) < MAX_BOXES:
-            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-        yield [history] + updates
+            chatbot.append([
+                "[用户]" + query,
+                "[ChatGLM-6B]" + response
+            ])
+        if len(chatbot) > MAX_TURNS:
+            chatbot = chatbot[-MAX_TURNS:]
+        yield history, WELCOME_PROMPT + chatbot
 
 
-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+with gr.Blocks(title="ChatGLM-6B", css='#main-chatbot { height: 480px; }') as demo:
+    input_cache = gr.State()
+    history = gr.State([])
 
     with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+        with gr.Column():
+            pass
+        with gr.Column():
+            chatbot = gr.Chatbot(
+                show_label=False,
+                elem_id="main-chatbot"
+            )
+            input = gr.Textbox(
+                show_label=False,
+                placeholder="Input text and press enter",
+                interactive=True,
+            )
+            with gr.Box():
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+        with gr.Column():
+            pass
+
+    input.submit(
+        lambda x: ("", x),
+        [input],
+        [input, input_cache]
+    ).then(
+        predict,
+        [input_cache, max_length, top_p, temperature, history],
+        [history, chatbot],
+    )
+
+    demo.load(
+        lambda: WELCOME_PROMPT,
+        None,
+        [chatbot]
+    )
 
 demo.queue().launch(share=False, inbrowser=True)
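
The rewrite replaces the old grid of hidden `gr.Markdown` boxes with a single `gr.Chatbot`, and `predict` becomes a generator that re-yields the full message list on every `stream_chat` step, which Gradio renders as streaming output (the `demo.queue()` call is what enables generator handlers). The `input.submit(...).then(...)` chain first clears the textbox while caching its value in `input_cache`, then runs `predict` on the cached text, so the box empties immediately while generation continues. The self-contained sketch below shows the same generator-plus-Chatbot pattern in the Gradio 3.x list-of-pairs format the PR targets, with echo logic standing in for `model.stream_chat`; all names are illustrative, not from the PR:

```python
import gradio as gr

def respond(message, history):
    # Append a new [user, assistant] turn, then grow the assistant side
    # incrementally; each yield repaints the whole Chatbot.
    history = history + [[message, ""]]
    for ch in "Echo: " + message:
        history[-1][1] += ch
        yield history, history  # -> (state, chatbot), like predict() above

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    box = gr.Textbox(placeholder="Input text and press enter")
    box.submit(respond, [box, state], [state, chatbot])

demo.queue().launch()
```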