rewrite web_demo.py with gradio.Chatbot and support loading the model from a local directory

pull/202/head
Guo Y.K 2023-03-23 15:59:38 +08:00
parent 5513dd7d2c
commit da09ca4dff
4 changed files with 71 additions and 33 deletions

README.md

@@ -53,6 +53,8 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese Q&A and dialogue
 ```
 The full model implementation is available on the [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b). If downloading the checkpoint from the Hugging Face Hub is slow, you can also download it manually from [here](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/).
+Place a manually downloaded model in the `model` directory.
 ### Demo
 We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo. To use them, first clone this repository
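The hunk above only says where a manual download should end up. As an illustration, a minimal sketch that mirrors the checkpoint into `./model` with `huggingface_hub` (an assumption on my part: it requires a release recent enough to support `local_dir`; the cloud-drive download linked above works just as well):

```python
# Minimal sketch: fetch the THUDM/chatglm-6b checkpoint into ./model so the
# demos pick it up through the MODEL_ID fallback introduced in this commit.
# Assumes `pip install huggingface_hub` with local_dir support.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="THUDM/chatglm-6b", local_dir="./model")
```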

README_en.md

@@ -51,6 +51,8 @@ Generate dialogue with the following code
 The full model implementation is on [HuggingFace Hub](https://huggingface.co/THUDM/chatglm-6b).
+Or you can download the model manually and put it in the `model` directory.
 ### Demo
 We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo in the repo. First clone our repo with:
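An incomplete manual download only fails once a demo tries to load it, so a quick pre-flight check can save a restart. The expected file names below are an assumption based on the layout of the checkpoint repository, not something this commit verifies:

```python
import os

# Hypothetical pre-flight check for a manually downloaded checkpoint;
# config.json and ice_text.model are assumed from the THUDM/chatglm-6b layout.
expected = ["config.json", "ice_text.model"]
missing = [name for name in expected
           if not os.path.exists(os.path.join("model", name))]
if missing:
    raise SystemExit(f"./model is incomplete, missing: {missing}")
print("./model looks complete")
```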

cli_demo.py

@@ -2,8 +2,10 @@ import os
 import platform
 from transformers import AutoTokenizer, AutoModel

-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+MODEL_ID = "./model" if os.path.exists('./model') else "THUDM/chatglm-6b"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda()
 model = model.eval()

 os_name = platform.system()
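Both demos drive generation through the model's custom `stream_chat` method, which ships with the remote code loaded via `trust_remote_code`. A minimal sketch of the generator contract as the demos use it (signature inferred from the calls in this commit, and assuming `tokenizer` and `model` are loaded as above):

```python
# stream_chat yields (partial_response, updated_history) on every step, so
# the last iteration carries the final reply and the full history.
history = []
for response, history in model.stream_chat(tokenizer, "你好", history,
                                           max_length=2048, top_p=0.7,
                                           temperature=0.95):
    pass  # response grows as tokens stream in

print(response)      # final reply
print(history[-1])   # ("你好", <final reply>)
```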

web_demo.py

@@ -1,45 +1,77 @@
+import os
 from transformers import AutoModel, AutoTokenizer
 import gradio as gr

-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+MODEL_ID = "./model" if os.path.exists('./model') else "THUDM/chatglm-6b"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda()
 model = model.eval()

 MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
+WELCOME_PROMPT = [[None, "[ChatGLM-6B]Welcome, please input text and press enter"]]

-def predict(input, max_length, top_p, temperature, history=None):
-    if history is None:
-        history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
-        updates = []
+def predict(input, max_length, top_p, temperature, history):
+    for _, history in model.stream_chat(
+            tokenizer, input, history,
+            max_length=max_length,
+            top_p=top_p,
+            temperature=temperature,
+    ):
+        chatbot = []
         for query, response in history:
-            updates.append(gr.update(visible=True, value="用户:" + query))
-            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
-        if len(updates) < MAX_BOXES:
-            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-        yield [history] + updates
+            chatbot.append([
+                "[用户]" + query,
+                "[ChatGLM-6B]" + response
+            ])
+        if len(chatbot) > MAX_TURNS:
+            chatbot = chatbot[-MAX_TURNS:]
+        yield history, WELCOME_PROMPT + chatbot

-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+with gr.Blocks(title="ChatGLM-6B", css='#main-chatbot { height: 480px; }') as demo:
+    input_cache = gr.State()
+    history = gr.State([])
     with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+        with gr.Column():
+            pass
+        with gr.Column():
+            chatbot = gr.Chatbot(
+                show_label=False,
+                elem_id="main-chatbot"
+            )
+            input = gr.Textbox(
+                show_label=False,
+                placeholder="Input text and press enter",
+                interactive=True,
+            )
+            with gr.Box():
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+        with gr.Column():
+            pass
+    input.submit(
+        lambda x: ("", x),
+        [input],
+        [input, input_cache]
+    ).then(
+        predict,
+        [input_cache, max_length, top_p, temperature, history],
+        [history, chatbot],
+    )
+    demo.load(
+        lambda: WELCOME_PROMPT,
+        None,
+        [chatbot]
+    )
 demo.queue().launch(share=False, inbrowser=True)
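The `input.submit(...).then(...)` chain is what lets the textbox clear immediately while the streamed generation follows: the first event is a cheap synchronous callback, and only the chained event runs `predict`. A stripped-down sketch of the same pattern, assuming a Gradio 3.x release where `.then()` is available:

```python
import gradio as gr

def slow_step(cached):
    # Stand-in for predict(): runs only after the textbox has been cleared.
    return cached.upper()

with gr.Blocks() as demo:
    cache = gr.State()
    box = gr.Textbox(placeholder="Type and press enter")
    out = gr.Textbox()
    # First event fires instantly (clear the box, stash its value),
    # then the chained event performs the slow work.
    box.submit(lambda x: ("", x), [box], [box, cache]).then(
        slow_step, [cache], [out])

demo.queue().launch()
```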