From 1887c72e36f156191521e4ec3642339927c3c377 Mon Sep 17 00:00:00 2001 From: openmartin Date: Wed, 5 Apr 2023 22:42:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20gradio=20chatbox=20?= =?UTF-8?q?=E6=9D=A5=E5=B1=95=E7=A4=BA=E8=81=8A=E5=A4=A9=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web_demo.py | 59 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/web_demo.py b/web_demo.py index 88a6dc8..6bd27ce 100644 --- a/web_demo.py +++ b/web_demo.py @@ -9,37 +9,48 @@ MAX_TURNS = 20 MAX_BOXES = MAX_TURNS * 2 -def predict(input, max_length, top_p, temperature, history=None): - if history is None: - history = [] - for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, +def btn_is_clickable(txt): + if txt is not None and txt.strip() != '': + return gr.update(interactive=True) + else: + return gr.update(interactive=False) + + +def user_message(user_message, history): + return "", history + [[user_message, None]] + + +def predict(input, max_length, top_p, temperature, history): + if len(history) > MAX_TURNS: + history = history[-20:] + for response, const_history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, temperature=temperature): - updates = [] - for query, response in history: - updates.append(gr.update(visible=True, value="用户:" + query)) - updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) - if len(updates) < MAX_BOXES: - updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) - yield [history] + updates + history[-1][1] = response + yield history with gr.Blocks() as demo: - state = gr.State([]) - text_boxes = [] - for i in range(MAX_BOXES): - if i % 2 == 0: - text_boxes.append(gr.Markdown(visible=False, label="提问:")) - else: - text_boxes.append(gr.Markdown(visible=False, label="回复:")) - with gr.Row(): - with gr.Column(scale=4): - txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( - container=False) with gr.Column(scale=1): + gr.Markdown("https://github.com/THUDM/ChatGLM-6B") max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) - button = gr.Button("Generate") - button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) + with gr.Column(scale=4): + chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750) + with gr.Row(): + with gr.Column(scale=4): + txt = gr.Textbox( + show_label=False, + placeholder="有问题就会有答案" + ).style(container=False) + with gr.Column(scale=1, min_width=0): + btn = gr.Button("发送", interactive=False) + # 控制按钮是否可以点击 + txt.change(btn_is_clickable, txt, btn) + # 发送消息 + btn.click(user_message, [txt, chatbot], [txt, chatbot], queue=False).then( + predict, [txt, max_length, top_p, temperature, chatbot], chatbot + ) + demo.queue().launch(share=False, inbrowser=True)