ChatGLM-6B/web_demo.py

53 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from transformers import AutoModel, AutoTokenizer
import gradio as gr
# Load the tokenizer and model weights for the "THUDM/chatglm-6b" checkpoint.
# trust_remote_code=True is passed to both loaders (presumably because the
# checkpoint ships its own model code; confirm against the transformers docs).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# NOTE(review): this line was originally labelled "GPU deployment", but it
# casts with .float() and never calls .cuda(); it is identical to the CPU
# variant listed below (which needs roughly 32GB of memory). Swap in one of
# the alternatives below if a GPU should be used.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
# Quantized GPU alternative; adjust as needed, only 4/8-bit quantization is supported for now:
#model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()
# Without GPU hardware, inference can also run on the CPU, but it will be slower
# (needs roughly 32GB of memory):
#model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
# If memory is insufficient, load the already-quantized model directly:
#model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
# Put the model in evaluation mode before serving requests.
model = model.eval()
# Number of conversation turns the page has display slots for.
MAX_TURNS = 20
# Each turn uses two display boxes: one for the question and one for the reply.
MAX_BOXES = MAX_TURNS * 2
def predict(input, max_length, top_p, temperature, history=None):
    """Stream a chat reply and yield UI updates after every generated chunk.

    Args:
        input: The user's new message. The name shadows the builtin but is
            kept so existing callers are unaffected.
        max_length: Forwarded to ``model.stream_chat`` as ``max_length``.
        top_p: Forwarded to ``model.stream_chat`` as ``top_p``.
        temperature: Forwarded to ``model.stream_chat`` as ``temperature``.
        history: Previous ``(query, response)`` pairs, or ``None`` to start
            a new conversation.

    Yields:
        A list whose first item is the updated history and whose remaining
        ``MAX_BOXES`` items are component updates: two visible boxes per
        displayed turn, then hidden boxes filling the unused slots.
    """
    if history is None:
        history = []
    # Two boxes are consumed per turn, so this is how many turns fit on screen.
    max_turns = MAX_BOXES // 2
    # Each step yields (partial_response, history); the partial response is
    # already the last entry of history, so only history is needed here.
    for _, history in model.stream_chat(tokenizer, input, history, max_length=max_length,
                                        top_p=top_p, temperature=temperature):
        updates = []
        # Show only the most recent turns: yielding more updates than there
        # are output components would make the event handler fail once the
        # conversation outgrows the available boxes. The full history is
        # still returned as state so the model keeps its context.
        for query, answer in history[-max_turns:]:
            updates.append(gr.update(visible=True, value="用户:" + query))
            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + answer))
        # Hide every slot that has no message yet.
        updates.extend(gr.update(visible=False) for _ in range(MAX_BOXES - len(updates)))
        yield [history] + updates
# Build the page: a stack of initially hidden message boxes, then an input
# row with the prompt on the left and the sampling controls on the right.
with gr.Blocks() as demo:
    state = gr.State([])
    # Even slots hold questions and odd slots hold replies, matching the
    # order in which predict() emits its updates.
    text_boxes = [
        gr.Markdown(visible=False, label="回复:" if slot % 2 else "提问:")
        for slot in range(MAX_BOXES)
    ]
    with gr.Row():
        with gr.Column(scale=4):
            txt = gr.Textbox(
                show_label=False,
                placeholder="输入文本并按Enter键",
                lines=11,
            ).style(container=False)
        with gr.Column(scale=1):
            max_length = gr.Slider(minimum=0, maximum=4096, value=2048, step=1.0,
                                   label="最大长度", interactive=True)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.01,
                                    label="氛围", interactive=True)
            button = gr.Button("生成")
    # The conversation state is both an input and the first output, so each
    # click continues from the history produced by the previous one.
    button.click(
        fn=predict,
        inputs=[txt, max_length, top_p, temperature, state],
        outputs=[state, *text_boxes],
    )
# Queuing is enabled before launch because predict() is a generator that
# streams intermediate results.
queued_demo = demo.queue()
queued_demo.launch(share=False, inbrowser=True)