diff --git a/web_demo.py b/web_demo.py
index e37f068..1b47f28 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -150,15 +150,15 @@ class GenerationConfig:
 
 
 @st.cache_resource
-def load_model():
-    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
-    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+def load_model(model_path):
+    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     return model, tokenizer
 
 
-def prepare_generation_config():
+def prepare_generation_config(max_value):
     with st.sidebar:
-        max_length = st.slider("Max Length", min_value=32, max_value=8192, value=2048)
+        max_length = st.slider("Max Length", min_value=32, max_value=max_value, value=2048)
         top_p = st.slider(
             'Top P', 0.0, 1.0, 0.8, step=0.01
         )
@@ -199,16 +199,38 @@ def combine_history(prompt):
 
 def main():
     torch.cuda.empty_cache()
+    while True:
+        choice = input("Please select a model (\033[33m1\033[0m. InternLM-7B / \033[33m2\033[0m. InternLM-Chat-7B / \033[33m3\033[0m. InternLM-Chat-7B-8K): ")
+        if choice == '1':
+            print("You selected \033[33mInternLM-7B\033[0m")
+            model_name = "InternLM-7B"
+            model_path = "internlm/internlm-7b"
+            max_value = 2048
+            break
+        elif choice == '2':
+            print("You selected \033[33mInternLM-Chat-7B\033[0m")
+            model_name = "InternLM-Chat-7B"
+            model_path = "internlm/internlm-chat-7b"
+            max_value = 2048
+            break
+        elif choice == '3':
+            print("You selected \033[33mInternLM-Chat-7B-8K\033[0m")
+            model_name = "InternLM-Chat-7B-8K"
+            model_path = "internlm/internlm-chat-7b-8k"
+            max_value = 8192
+            break
+        else:
+            print("Invalid selection")
     print("load model begin.")
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(model_path)
     print("load model end.")
 
     user_avator = "doc/imgs/user.png"
     robot_avator = "doc/imgs/robot.png"
 
-    st.title("InternLM-Chat-7B")
+    st.title(model_name)
 
-    generation_config = prepare_generation_config()
+    generation_config = prepare_generation_config(max_value)
 
     # Initialize chat history
     if "messages" not in st.session_state:
@@ -244,10 +266,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-
-
-
-
-
-