diff --git a/web_demo2.py b/web_demo2.py index 9ae6b26..4e1f0e4 100644 --- a/web_demo2.py +++ b/web_demo2.py @@ -25,20 +25,31 @@ def predict(input, history=None): tokenizer, model = get_model() if history is None: history = [] - response, history = model.chat(tokenizer, input, history) - for i, (query, response) in enumerate(history): - message(query, avatar_style="big-smile", key=str(i) + "_user") - message(response, avatar_style="bottts", key=str(i)) + with container: + if len(history) > 0: + for i, (query, response) in enumerate(history): + message(query, avatar_style="big-smile", key=str(i) + "_user") + message(response, avatar_style="bottts", key=str(i)) + + message(input, avatar_style="big-smile", key=str(len(history)) + "_user") + st.write("AI正在回复:") + with st.empty(): + for response, history in model.stream_chat(tokenizer, input, history): + query, response = history[-1] + st.write(response) return history +container = st.container() + # create a prompt text for the text generation prompt_text = st.text_area(label="用户命令输入", height = 100, placeholder="请在这儿输入您的命令") + if 'state' not in st.session_state: st.session_state['state'] = [] @@ -46,5 +57,3 @@ if st.button("发送", key="predict"): with st.spinner("AI正在思考,请稍等........"): # text generation st.session_state["state"] = predict(prompt_text, st.session_state["state"]) - - st.balloons() \ No newline at end of file