diff --git a/web_demo.py b/web_demo.py
index 26de0ba..7576b99 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -6,7 +6,7 @@ Please refer to these links below for more information:
 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
 3. transformers: https://github.com/huggingface/transformers
 """
-
+import argparse
 from dataclasses import asdict
 
 import streamlit as st
@@ -24,13 +24,13 @@ def on_btn_click():
 
 
 @st.cache_resource
-def load_model():
+def load_model(model_path, tokenizer_path):
     model = (
-        AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+        AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True)
         .to(torch.bfloat16)
         .cuda()
     )
-    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
     return model, tokenizer
 
 
@@ -46,14 +46,21 @@ def prepare_generation_config():
     return generation_config
 
 
+system_desc = """<|System|>:You are an AI assistant whose name is InternLM (书生·浦语).
+- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
+- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.
+\n"""
 user_prompt = "<|User|>:{user}\n"
 robot_prompt = "<|Bot|>:{robot}\n"
 cur_query_prompt = "<|User|>:{user}\n<|Bot|>:"
 
 
-def combine_history(prompt):
+def combine_history(prompt, system=True):
     messages = st.session_state.messages
     total_prompt = ""
+    if system:
+        total_prompt += system_desc
+
     for message in messages:
         cur_content = message["content"]
         if message["role"] == "user":
@@ -64,13 +71,14 @@ def combine_history(prompt):
             raise RuntimeError
         total_prompt += cur_prompt
     total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    print(messages, total_prompt)
     return total_prompt
 
 
-def main():
+def main(args):
     # torch.cuda.empty_cache()
     print("load model begin.")
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(args.model_path, args.tokenizer_path)
     print("load model end.")
 
     user_avator = "doc/imgs/user.png"
@@ -116,4 +124,8 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, default="internlm/internlm-chat-7b", help="Path to the model")
+    parser.add_argument("--tokenizer_path", type=str, default="internlm/internlm-chat-7b", help="Path to the tokenizer")
+    args = parser.parse_args()
+    main(args)
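
Usage note: with these changes the model and tokenizer locations are configurable from the command line. Because Streamlit consumes any flags placed before the script name, per-script arguments must follow the "--" separator. A minimal launch sketch (the local paths are placeholders, not paths from this patch):

    streamlit run web_demo.py -- --model_path /path/to/internlm-chat-7b --tokenizer_path /path/to/internlm-chat-7b

Omitting both flags falls back to the defaults, which download internlm/internlm-chat-7b from the Hugging Face Hub as before.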