"""
This script is based on the Streamlit dialogue example and on the interactive
generation code of ChatGLM2 and transformers. We mainly modified part of the
code logic to adapt it to generation with our model.

Please refer to the links below for more information:
    1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
    3. transformers: https://github.com/huggingface/transformers
"""

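# A typical way to launch this demo (the file name is an assumption; use the
# script's actual path in the repo):
#   streamlit run web_demo.py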
from dataclasses import asdict

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

from tools.transformers.interface import GenerationConfig, generate_interactive

logger = logging.get_logger(__name__)


def on_btn_click():
    """Callback for the sidebar button: drop the stored chat history."""
    del st.session_state.messages


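# st.cache_resource keeps a single model/tokenizer instance alive for the whole
# Streamlit process, so reruns triggered by UI events do not reload the weights.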
@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
    return model, tokenizer


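# Sidebar widgets are recreated on each rerun and return their current values,
# so the GenerationConfig built here always reflects the latest slider state.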
def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)

    return generation_config


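# Prompt templates for InternLM's chat format: past user turns are replayed as
# <|User|>:...\n, past bot turns as <|Bot|>:...<eoa>\n (end-of-answer), and the
# current query ends with <eoh> (end-of-human) followed by an open <|Bot|>: tag
# so the model continues from the bot's side of the dialogue.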
user_prompt = "<|User|>:{user}\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"


def combine_history(prompt):
    """Replay the stored conversation in InternLM's chat format and append the current query."""
    messages = st.session_state.messages
    total_prompt = ""
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.replace("{user}", cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.replace("{robot}", cur_content)
        else:
            raise RuntimeError(f"Unknown role in chat history: {message['role']}")
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
    return total_prompt


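# Streamlit re-executes this script from top to bottom on every interaction,
# so all conversation state must live in st.session_state rather than in
# local variables.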
def main():
    # torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avatar = "docs/imgs/user.png"
    robot_avatar = "docs/imgs/robot.png"

    st.title("InternLM-Chat-7B")

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avatar):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avatar})

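        # generate_interactive (from tools/transformers/interface.py) appears to
        # yield progressively longer decoded responses, which lets the placeholder
        # below render the answer as it streams.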
        with st.chat_message("robot", avatar=robot_avatar):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                # 103028 is the id of InternLM's <eoa> (end-of-answer) token,
                # treated as an additional EOS during generation.
                additional_eos_token_id=103028,
                **asdict(generation_config),
            ):
                # Display the partial robot response with a cursor placeholder
                message_placeholder.markdown(cur_response + "▌")
            # Final render without the cursor once generation is done
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avatar})
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()