添加清除历史对话按钮

pull/165/head
Lu Guanghua 2023-07-03 20:35:18 +08:00 committed by GitHub
parent b519f9a092
commit b25d2aa38f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 7 deletions

View File

@ -2,14 +2,12 @@ from transformers import AutoModel, AutoTokenizer
import streamlit as st
from streamlit_chat import message
st.set_page_config(
page_title="ChatGLM2-6b 演示",
page_icon=":robot:",
page_icon=":robot:"
layout='wide'
)
@st.cache_resource
def get_model():
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
@ -20,12 +18,13 @@ def get_model():
model = model.eval()
return tokenizer, model
MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2
#在启动时加载模型
get_model()
def predict(input, max_length, top_p, temperature, history=None):
def predict(input, history=None):
tokenizer, model = get_model()
if history is None:
history = []
@ -48,6 +47,9 @@ def predict(input, max_length, top_p, temperature, history=None):
return history
#清除对话历史
def clean():
st.session_state["state"] = None
container = st.container()
@ -72,4 +74,7 @@ if 'state' not in st.session_state:
if st.button("发送", key="predict"):
with st.spinner("AI正在思考请稍等........"):
# text generation
clean_button = st.button("新对话", on_click=clean)
st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
st.session_state["state"]