diff --git a/web_demo2.py b/web_demo2.py
index 6c66308..8b85b10 100644
--- a/web_demo2.py
+++ b/web_demo2.py
@@ -2,14 +2,12 @@
 import streamlit as st
 from streamlit_chat import message
 
-
 st.set_page_config(
     page_title="ChatGLM2-6b 演示",
     page_icon=":robot:",
     layout='wide'
 )
 
-
 @st.cache_resource
 def get_model():
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
@@ -20,19 +18,20 @@ def get_model():
     model = model.eval()
     return tokenizer, model
 
-
 MAX_TURNS = 20
 MAX_BOXES = MAX_TURNS * 2
 
+# Load the model once at startup; @st.cache_resource makes later calls free.
+get_model()
 
-def predict(input, max_length, top_p, temperature, history=None):
+def predict(input, history=None):
     tokenizer, model = get_model()
     if history is None:
         history = []
 
     with container:
         if len(history) > 0:
-            if len(history)>MAX_BOXES:
+            if len(history) > MAX_BOXES:
                 history = history[-MAX_TURNS:]
             for i, (query, response) in enumerate(history):
                 message(query, avatar_style="big-smile", key=str(i) + "_user")
@@ -42,12 +41,15 @@ def predict(input, max_length, top_p, temperature, history=None):
         st.write("AI正在回复:")
         with st.empty():
             for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                                       temperature=temperature):
+                                               temperature=temperature):
                 query, response = history[-1]
                 st.write(response)
 
     return history
 
+# Clear the stored conversation so the next rerun starts a fresh chat.
+def clean():
+    st.session_state["state"] = None
 
 container = st.container()
@@ -72,4 +74,6 @@ if 'state' not in st.session_state:
+# NOTE(review): the "new conversation" button must be created at top level,
+# not inside the send-button branch, or it only renders on that one rerun.
+clean_button = st.button("新对话", on_click=clean)
 if st.button("发送", key="predict"):
     with st.spinner("AI正在思考,请稍等........"):
         # text generation
-        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
+        st.session_state["state"] = predict(prompt_text, st.session_state["state"])