From 830c0cc2a87e577ec4610024e572785045df66db Mon Sep 17 00:00:00 2001
From: LYMDLUT <70597027+LYMDLUT@users.noreply.github.com>
Date: Fri, 14 Jul 2023 00:52:35 +0800
Subject: [PATCH] Update web_demo.py

---
 web_demo.py | 39 ++++++++++++---------------------------
 1 file changed, 12 insertions(+), 27 deletions(-)

diff --git a/web_demo.py b/web_demo.py
index 1b47f28..8e9f298 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -6,6 +6,7 @@ Please refer to these links below for more information:
 3. transformers: https://github.com/huggingface/transformers
 """
 
+import argparse
 import streamlit as st
 import torch
 import torch.nn as nn
@@ -196,41 +197,24 @@ def combine_history(prompt):
     total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
     return total_prompt
+def parse_args():
+    parser = argparse.ArgumentParser(description='Chat with the model, please add "--" between web_demo.py and Args in terminal')
+    parser.add_argument('--model_path', default='internlm/internlm-7b', help='path of the InternLM model')
+    parser.add_argument('--max_value', type=int, default=2048, help='the max length of the generated text')
+    return parser.parse_args()
 
 
-def main():
+def main(args):
     torch.cuda.empty_cache()
-    while True:
-        choice = input("Please select a number from(\033[33m1\033[0m.Internlm-7b/\033[33m2\033[0m.Internlm-chat-7b/\033[33m3\033[0m.Internlm-chat-7b-8k): ")
-        if choice == '1':
-            print("You select \033[33mInternlm-7b\033[0m")
-            model_name = "Internlm-7b"
-            model_path = "internlm/internlm-7b"
-            max_value = 2048
-            break
-        elif choice == '2':
-            print("You select \033[33mInternlm-chat-7b\033[0m")
-            model_name = "Internlm-chat-7b"
-            model_path = "internlm/internlm-chat-7b"
-            max_value = 2048
-            break
-        elif choice == '3':
-            print("You select \033[33mInternlm-chat-7b-8k\033[0m")
-            model_name = "Internlm-chat-7b-8k"
-            model_path = "internlm/internlm-chat-7b-8k"
-            max_value = 8192
-            break
-        else:
-            print("Invalid selection")
     print("load model begin.")
-    model, tokenizer = load_model(model_path)
+    model, tokenizer = load_model(args.model_path)
     print("load model end.")
 
     user_avator = "doc/imgs/user.png"
     robot_avator = "doc/imgs/robot.png"
 
-    st.title(model_name)
+    st.title("InternLm-7b")
 
-    generation_config = prepare_generation_config(max_value)
+    generation_config = prepare_generation_config(args.max_value)
 
     # Initialize chat history
     if "messages" not in st.session_state:
@@ -264,5 +248,6 @@ def main():
         torch.cuda.empty_cache()
 
 
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)