diff --git a/web_demo2.py b/web_demo2.py index 203cbdc..2c91ed3 100644 --- a/web_demo2.py +++ b/web_demo2.py @@ -45,19 +45,14 @@ for i, (query, response) in enumerate(st.session_state.history): st.markdown(query) with st.chat_message(name="assistant", avatar="assistant"): st.markdown(response) -with st.chat_message(name="user", avatar="user"): - input_placeholder = st.empty() -with st.chat_message(name="assistant", avatar="assistant"): - message_placeholder = st.empty() -prompt_text = st.text_area(label="用户命令输入", - height=100, - placeholder="请在这儿输入您的命令") +prompt_text = st.chat_input(placeholder="请在这儿输入您的命令") +if prompt_text: + with st.chat_message(name="user", avatar="user"): + st.markdown(prompt_text) + with st.chat_message(name="assistant", avatar="assistant"): + message_placeholder = st.empty() -button = st.button("发送", key="predict") - -if button: - input_placeholder.markdown(prompt_text) history, past_key_values = st.session_state.history, st.session_state.past_key_values for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history, past_key_values=past_key_values,