mirror of https://github.com/THUDM/ChatGLM2-6B
from transformers import AutoModel, AutoTokenizer
import streamlit as st


st.set_page_config(
    page_title="ChatGLM2-6b Demo",
    page_icon=":robot:",
    layout='wide'
)


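# st.cache_resource caches the returned tokenizer/model across Streamlit reruns,
# so the weights are loaded from disk only once per server process.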
@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # For multi-GPU support, use the two lines below instead of the line above,
    # changing num_gpus to the number of GPUs you actually have:
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model = model.eval()
    return tokenizer, model
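
# A lower-memory alternative (an assumption based on the upstream ChatGLM2-6B
# README, not part of this file): load the quantized INT4 checkpoint instead,
# e.g. model = AutoModel.from_pretrained("THUDM/chatglm2-6b-int4",
#                                        trust_remote_code=True).cuda()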


tokenizer, model = get_model()

st.title("ChatGLM2-6B")

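# Generation controls: max_length bounds the total token count (prompt, history,
# and reply combined); top_p and temperature tune nucleus sampling.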
max_length = st.sidebar.slider(
    'max_length', 0, 32768, 8192, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.8, step=0.01
)

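# Streamlit re-executes this whole script on every interaction, so the chat
# history and the model's KV cache are kept in st.session_state to survive reruns.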
if 'history' not in st.session_state:
    st.session_state.history = []

if 'past_key_values' not in st.session_state:
    st.session_state.past_key_values = None

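# Replay the transcript so far; each rerun redraws every past exchange.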
for i, (query, response) in enumerate(st.session_state.history):
    with st.chat_message(name="user", avatar="user"):
        st.markdown(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(response)
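# Reserve two empty slots for the new exchange; they are updated in place below,
# which is what makes the reply appear to stream token by token.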
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

prompt_text = st.text_area(label="User command input",
                           height=100,
                           placeholder="Enter your command here")

button = st.button("Send", key="predict")

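# stream_chat yields the progressively longer response on each step; feeding the
# cached past_key_values back in lets the model avoid re-encoding earlier turns.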
if button:
    input_placeholder.markdown(prompt_text)
    history, past_key_values = st.session_state.history, st.session_state.past_key_values
    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
                                                                past_key_values=past_key_values,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature,
                                                                return_past_key_values=True):
        message_placeholder.markdown(response)

    st.session_state.history = history
    st.session_state.past_key_values = past_key_values
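
# Launch with: streamlit run web_demo2.py
# (filename is an assumption; that is what the upstream repo calls this demo)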