添加清除历史对话按钮

pull/165/head
Lu Guanghua 2023-07-03 20:35:18 +08:00 committed by GitHub
parent b519f9a092
commit b25d2aa38f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 12 additions and 7 deletions

View File

@@ -2,14 +2,12 @@ from transformers import AutoModel, AutoTokenizer
import streamlit as st
from streamlit_chat import message
# Page chrome for the demo.
# NOTE(review): the diff view rendered the removed and added `page_icon` lines
# back to back with the `+`/`-` markers stripped, leaving a missing comma before
# `layout` (a syntax error as rendered). Reconstructed post-commit call below:
# the commit's change is adding the wide layout.
st.set_page_config(
    page_title="ChatGLM2-6b 演示",
    page_icon=":robot:",
    # Wide layout gives the chat transcript the full browser width.
    layout='wide'
)
# NOTE(review): truncated fragment — this hunk ends here and the diff omits the
# model-construction lines that follow (they appear after the next hunk header).
@st.cache_resource
def get_model():
    # Loads the ChatGLM2-6B tokenizer from the HF hub; cached once per server
    # process by @st.cache_resource so reruns don't reload it.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
@@ -20,19 +18,20 @@ def get_model():
# NOTE(review): tail of get_model(); the `model = AutoModel.from_pretrained(...)`
# line(s) above were omitted between diff hunks — confirm against the full file.
# eval() disables dropout/batch-norm updates for inference.
model = model.eval()
return tokenizer, model
# Keep at most this many conversation turns; each turn renders two chat
# bubbles (user query + model response), hence the box cap below.
MAX_TURNS = 20
MAX_BOXES = 2 * MAX_TURNS
# Load the model at startup so the cached resource is warm before the
# first user request (original comment: 在启动时加载模型).
get_model()
# NOTE(review): rendered diff hunk — removed and added lines appear back to
# back with their +/- markers stripped; comments below mark each old/new pair.
def predict(input, max_length, top_p, temperature, history=None):
# ^ old signature (removed by this commit); `input` shadows the builtin
def predict(input, history=None):
# ^ new signature: the sampling params were dropped, yet the stream_chat call
#   below still references max_length/top_p/temperature — so they must resolve
#   as module-level globals (sidebar sliders, presumably) — TODO confirm.
tokenizer, model = get_model()
if history is None:
history = []
with container:
if len(history) > 0:
if len(history)>MAX_BOXES:
# ^ old spacing (removed)
if len(history) > MAX_BOXES:
# ^ new spacing (added): once the rendered-box cap is hit,
#   keep only the most recent MAX_TURNS exchanges
history = history[-MAX_TURNS:]
for i, (query, response) in enumerate(history):
message(query, avatar_style="big-smile", key=str(i) + "_user")
@@ -42,12 +41,15 @@ def predict(input, max_length, top_p, temperature, history=None):
# NOTE(review): continuation of predict() from the next diff hunk; the lines
# between hunks (rendering the response bubbles / current query) are omitted
# by the diff view.
st.write("AI正在回复:")
with st.empty():
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
# NOTE(review): the two identical continuation lines below are the
# removed/added re-indentation of the same line, rendered back to back.
temperature=temperature):
temperature=temperature):
query, response = history[-1]
# st.empty() overwrites in place, so each partial response from the
# streaming generator replaces the previous one.
st.write(response)
return history
# Clear the conversation history (original comment: 清除对话历史).
def clean():
    # on_click callback for the "新对话" (new chat) button: nulling the session
    # key makes predict() start from an empty history on the next rerun.
st.session_state["state"] = None
# Top-level container that predict() renders the chat transcript into.
container = st.container()
@@ -72,4 +74,7 @@ if 'state' not in st.session_state:
# NOTE(review): final diff hunk — the "新对话" (new chat) button addition.
# Indentation was stripped by the diff view, so statement nesting is ambiguous.
if st.button("发送", key="predict"):
with st.spinner("AI正在思考请稍等........"):
# text generation
clean_button = st.button("新对话", on_click=clean)
# ^ added line; presumably placed outside the spinner block in the committed
#   file — the rendered diff loses its true position; verify against the file.
st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
# NOTE(review): this call passes 5 positional args, but the new signature is
# predict(input, history=None) taking 2 — if this line is unchanged in the
# commit it raises TypeError at click time; more likely the diff omitted the
# updated call line. Confirm against the full committed file.
st.session_state["state"]
# ^ bare expression — Streamlit "magic" presumably echoes the history to the
#   page; confirm this is intentional and not leftover debugging.