|
|
|
@ -11,8 +11,8 @@ st.set_page_config(
|
|
|
|
|
|
|
|
|
|
@st.cache_resource |
|
|
|
|
def get_model(): |
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("/THUDM/chatglm-6b", trust_remote_code=True) |
|
|
|
|
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() |
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("/data/chatglm-6b", trust_remote_code=True) |
|
|
|
|
model = AutoModel.from_pretrained("/data/chatglm-6b", trust_remote_code=True).half().cuda() |
|
|
|
|
model = model.eval() |
|
|
|
|
return tokenizer, model |
|
|
|
|
|
|
|
|
@ -25,26 +25,31 @@ def predict(input, history=None):
|
|
|
|
|
tokenizer, model = get_model() |
|
|
|
|
if history is None: |
|
|
|
|
history = [] |
|
|
|
|
response, history = model.chat(tokenizer, input, history) |
|
|
|
|
|
|
|
|
|
#updates = [] |
|
|
|
|
for i, (query, response) in enumerate(history): |
|
|
|
|
#updates.append("用户:" + query) |
|
|
|
|
message(query, avatar_style="big-smile", key=str(i) + "_user") |
|
|
|
|
#updates.append("ChatGLM-6B:" + response) |
|
|
|
|
message(response, avatar_style="bottts", key=str(i)) |
|
|
|
|
with container: |
|
|
|
|
if len(history) > 0: |
|
|
|
|
for i, (query, response) in enumerate(history): |
|
|
|
|
message(query, avatar_style="big-smile", key=str(i) + "_user") |
|
|
|
|
message(response, avatar_style="bottts", key=str(i)) |
|
|
|
|
|
|
|
|
|
# if len(updates) < MAX_BOXES: |
|
|
|
|
# updates = updates + [""] * (MAX_BOXES - len(updates)) |
|
|
|
|
message(input, avatar_style="big-smile", key=str(len(history)) + "_user") |
|
|
|
|
st.write("AI正在回复:") |
|
|
|
|
with st.empty(): |
|
|
|
|
for response, history in model.stream_chat(tokenizer, input, history): |
|
|
|
|
query, response = history[-1] |
|
|
|
|
st.write(response) |
|
|
|
|
|
|
|
|
|
return history |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
container = st.container() |
|
|
|
|
|
|
|
|
|
# create a prompt text for the text generation |
|
|
|
|
prompt_text = st.text_area(label="用户命令输入", |
|
|
|
|
height = 100, |
|
|
|
|
placeholder="请在这儿输入您的命令") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'state' not in st.session_state: |
|
|
|
|
st.session_state['state'] = [] |
|
|
|
|
|
|
|
|
@ -53,4 +58,4 @@ if st.button("发送", key="predict"):
|
|
|
|
|
# text generation |
|
|
|
|
st.session_state["state"] = predict(prompt_text, st.session_state["state"]) |
|
|
|
|
|
|
|
|
|
st.balloons() |
|
|
|
|
st.session_state["state"] |