From b25d2aa38f2c5a01a0f1190f5a9113fb77037712 Mon Sep 17 00:00:00 2001
From: Lu Guanghua <102669562+Touch-Night@users.noreply.github.com>
Date: Mon, 3 Jul 2023 20:35:18 +0800
Subject: [PATCH] Add a button to clear the conversation history
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Load the model once at startup through the cached get_model(), and add
a "新对话" (new conversation) button that resets the history kept in
st.session_state via an on_click callback.
---
 web_demo2.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/web_demo2.py b/web_demo2.py
index 6c66308..8b85b10 100644
--- a/web_demo2.py
+++ b/web_demo2.py
@@ -24,6 +24,9 @@ def get_model():
 MAX_TURNS = 20
 MAX_BOXES = MAX_TURNS * 2
 
+# Load the model once at startup rather than on the first request
+get_model()
+
 
 def predict(input, max_length, top_p, temperature, history=None):
     tokenizer, model = get_model()
@@ -49,6 +52,11 @@ def predict(input, max_length, top_p, temperature, history=None):
     return history
 
 
+# Reset the conversation history kept in the session
+def clean():
+    st.session_state["state"] = []
+
+
 container = st.container()
 
 # create a prompt text for the text generation
@@ -72,4 +80,7 @@ if 'state' not in st.session_state:
 if st.button("发送", key="predict"):
     with st.spinner("AI正在思考,请稍等........"):
         # text generation
         st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
+
+# Start a new conversation by discarding the stored history
+st.button("新对话", key="clean", on_click=clean)
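
The patch leans on two Streamlit idioms that are easy to miss in the diff.
First, st.cache_resource memoizes get_model(), so the new module-level call
turns model loading into a one-time startup cost instead of a delay on the
first request. Second, on_click callbacks run before the script reruns, so
the history cleared by clean() is already gone when the page next renders.
The standalone sketch below isolates that pattern; the echo stub, English
labels, and file name are illustrative stand-ins for ChatGLM2-6B and are not
part of the patch.

    # demo_clear_history.py -- minimal sketch of the clear-history pattern.
    # The "model" is an echo stub so the script runs without ChatGLM2-6B;
    # start it with: streamlit run demo_clear_history.py
    import streamlit as st

    @st.cache_resource
    def get_model():
        # Stand-in for the real tokenizer/model load; the decorated body
        # runs only once per server process, then the result is reused.
        return lambda prompt: f"echo: {prompt}"

    model = get_model()  # eager load at startup, as in the patch

    if "state" not in st.session_state:
        st.session_state["state"] = []  # list of (query, response) pairs

    def clean():
        # on_click callbacks run before the rerun, so the next render
        # already sees the emptied history
        st.session_state["state"] = []

    prompt_text = st.text_area("Input")

    if st.button("Send") and prompt_text:
        response = model(prompt_text)  # stands in for model.stream_chat(...)
        st.session_state["state"].append((prompt_text, response))

    st.button("New conversation", on_click=clean)

    for query, response in st.session_state["state"]:
        st.write(f"User: {query}")
        st.write(f"AI: {response}")

Because web_demo2.py redraws past turns from the stored history on each run,
resetting that one session_state entry is all a new conversation needs.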