mirror of https://github.com/InternLM/InternLM
branch: web
parent aaaf4d7b0e
commit 0b00fd331d
web_demo.py (28 changed lines)
@@ -6,7 +6,7 @@ Please refer to these links below for more information:
 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
 3. transformers: https://github.com/huggingface/transformers
 """
-
+import argparse
 from dataclasses import asdict
 
 import streamlit as st
@@ -24,13 +24,13 @@ def on_btn_click():
 
 
 @st.cache_resource
-def load_model():
+def load_model(model_path, tokenizer_path):
     model = (
-        AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+        AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True)
         .to(torch.bfloat16)
         .cuda()
     )
-    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, low_cpu_mem_usage=True)
     return model, tokenizer
 
 
@@ -46,14 +46,21 @@ def prepare_generation_config():
     return generation_config
 
 
+system_desc="""<|System|>:You are an AI assistant whose name is InternLM (书生·浦语).
+- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
+- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.
+\n"""
 user_prompt = "<|User|>:{user}\n"
 robot_prompt = "<|Bot|>:{robot}<eoa>\n"
 cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
 
 
-def combine_history(prompt):
+def combine_history(prompt, system=True):
     messages = st.session_state.messages
     total_prompt = ""
+    if system:
+        total_prompt += system_desc
+
     for message in messages:
         cur_content = message["content"]
         if message["role"] == "user":
@@ -64,13 +71,14 @@ def combine_history(prompt):
             raise RuntimeError
         total_prompt += cur_prompt
     total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    print(messages, total_prompt)
     return total_prompt
 
 
-def main():
+def main(args):
     # torch.cuda.empty_cache()
     print("load model begin.")
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(args.model_path, args.tokenizer_path)
     print("load model end.")
 
     user_avator = "doc/imgs/user.png"
@@ -116,4 +124,8 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, default='internlm/internlm-chat-7b', help="Path to the model")
+    parser.add_argument("--tokenizer_path", type=str, default='internlm/internlm-chat-7b', help="Path to the tokenizer")
+    args = parser.parse_args()
+    main(args)
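
For reference, here is a minimal sketch of the prompt that the new combine_history assembles, using only the templates added in this commit (system_desc is shortened here, and the two-turn history is made up for illustration):

    # Templates from the diff above; the example conversation is hypothetical.
    system_desc = "<|System|>:You are an AI assistant whose name is InternLM (书生·浦语).\n"
    user_prompt = "<|User|>:{user}\n"
    robot_prompt = "<|Bot|>:{robot}<eoa>\n"
    cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"

    total_prompt = system_desc                                          # the system=True path
    total_prompt += user_prompt.replace("{user}", "Hi!")                # a past user turn
    total_prompt += robot_prompt.replace("{robot}", "Hello!")           # the model's past reply
    total_prompt += cur_query_prompt.replace("{user}", "Who are you?")  # the current query
    print(total_prompt)

With the new flags, the demo can presumably be launched against local weights by passing arguments through streamlit's "--" separator, e.g.:

    streamlit run web_demo.py -- --model_path ./internlm-chat-7b --tokenizer_path ./internlm-chat-7b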