pull/569/head
XHr 2024-01-04 13:06:44 +08:00 committed by GitHub
parent aaaf4d7b0e
commit 0b00fd331d
1 changed file with 20 additions and 8 deletions
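
This commit parameterizes the InternLM Streamlit chat demo: the hard-coded "internlm/internlm-chat-7b" checkpoint becomes --model_path/--tokenizer_path command-line arguments, model loading gains low_cpu_mem_usage=True, and combine_history() can now prepend the InternLM system prompt. Short illustrative sketches (not part of the commit) follow the relevant hunks below.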


@@ -6,7 +6,7 @@ Please refer to these links below for more information:
 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
 3. transformers: https://github.com/huggingface/transformers
 """
+import argparse
 from dataclasses import asdict
 import streamlit as st
@@ -24,13 +24,13 @@ def on_btn_click():
 @st.cache_resource
-def load_model():
+def load_model(model_path, tokenizer_path):
     model = (
-        AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+        AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True)
         .to(torch.bfloat16)
         .cuda()
     )
-    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, low_cpu_mem_usage=True)
     return model, tokenizer
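
Because load_model keeps the @st.cache_resource decorator, Streamlit caches its return value per distinct argument tuple, so each (model_path, tokenizer_path) pair is loaded once rather than on every script rerun. A minimal sketch of that caching behavior, with a hypothetical load_resource standing in for the real loader (illustrative only; run it under streamlit run to get a live cache):

import streamlit as st

@st.cache_resource
def load_resource(path):
    # The body runs only on a cache miss; later calls with the same
    # argument return the stored object.
    print(f"loading {path}")
    return {"path": path}

a = load_resource("model-a")  # miss: prints "loading model-a"
b = load_resource("model-a")  # hit: no print, same object as `a`
c = load_resource("model-b")  # different argument: loads again

A practical consequence: restarting the demo with a different --model_path triggers a fresh load instead of silently reusing the previously cached checkpoint.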
@@ -46,14 +46,21 @@ def prepare_generation_config():
     return generation_config

+system_desc="""<|System|>:You are an AI assistant whose name is InternLM (书生·浦语).
+- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
+- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.
+\n"""
+
 user_prompt = "<|User|>:{user}\n"
 robot_prompt = "<|Bot|>:{robot}<eoa>\n"
 cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"

-def combine_history(prompt):
+def combine_history(prompt, system=True):
     messages = st.session_state.messages
     total_prompt = ""
+    if system:
+        total_prompt += system_desc
     for message in messages:
         cur_content = message["content"]
         if message["role"] == "user":
@@ -64,13 +71,14 @@ def combine_history(prompt):
             raise RuntimeError
         total_prompt += cur_prompt
     total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    print(messages, total_prompt)
     return total_prompt

-def main():
+def main(args):
     # torch.cuda.empty_cache()
     print("load model begin.")
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(args.model_path, args.tokenizer_path)
     print("load model end.")

     user_avator = "doc/imgs/user.png"
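
With the new system flag, combine_history prepends system_desc before replaying the stored chat turns, then appends the pending query with an open <|Bot|>: tag for generation. A standalone sketch of the resulting prompt layout, assuming a plain list in place of st.session_state and an abridged system_desc (illustrative only):

# Abridged copies of the module-level templates from the diff above.
system_desc = "<|System|>:You are an AI assistant whose name is InternLM (书生·浦语).\n"
user_prompt = "<|User|>:{user}\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"

def combine_history(messages, prompt, system=True):
    # Optionally prepend the system prompt, then replay each stored turn.
    total_prompt = system_desc if system else ""
    for message in messages:
        if message["role"] == "user":
            total_prompt += user_prompt.replace("{user}", message["content"])
        else:
            total_prompt += robot_prompt.replace("{robot}", message["content"])
    # Close with the pending user query and an open <|Bot|>: tag.
    return total_prompt + cur_query_prompt.replace("{user}", prompt)

history = [{"role": "user", "content": "Hi"}, {"role": "robot", "content": "Hello!"}]
print(combine_history(history, "Who are you?"))
# <|System|>:You are an AI assistant whose name is InternLM (书生·浦语).
# <|User|>:Hi
# <|Bot|>:Hello!<eoa>
# <|User|>:Who are you?<eoh>
# <|Bot|>: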
@@ -116,4 +124,8 @@ def main():

 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, default='internlm/internlm-chat-7b', help="Path to the model")
+    parser.add_argument("--tokenizer_path", type=str, default='internlm/internlm-chat-7b', help="Path to the tokenizer")
+    args = parser.parse_args()
+    main(args)
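
The new __main__ block lets the checkpoint locations be overridden from the command line. When launching through Streamlit, script-level flags go after a -- separator, e.g. streamlit run web_demo.py -- --model_path /data/internlm-chat-7b (the filename and path here are assumptions; the diff does not show them). A small self-contained check of the parsing behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="internlm/internlm-chat-7b", help="Path to the model")
parser.add_argument("--tokenizer_path", type=str, default="internlm/internlm-chat-7b", help="Path to the tokenizer")

# Parse an explicit argv (placeholder path) instead of sys.argv so the
# check runs anywhere.
args = parser.parse_args(["--model_path", "/data/internlm-chat-7b"])
assert args.model_path == "/data/internlm-chat-7b"
assert args.tokenizer_path == "internlm/internlm-chat-7b"  # default preserved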