mirror of https://github.com/InternLM/InternLM
Update web_demo.py
parent e66bd817c1
commit 830c0cc2a8
web_demo.py (39 changed lines)
@@ -6,6 +6,7 @@ Please refer to these links below for more information:
 3. transformers: https://github.com/huggingface/transformers
 """
 
+import argparse
 import streamlit as st
 import torch
 import torch.nn as nn
@@ -196,41 +197,24 @@ def combine_history(prompt):
     total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
     return total_prompt
 
 
-def main():
+def parse_args():
+    parser = argparse.ArgumentParser(description='Chat with the model, please add "--" between web_demo.py and Args in terminal')
+    parser.add_argument('--model_path', default='internlm/internlm-7b', help='path of the InternLm model')
+    parser.add_argument('--max_value', default=2048, help='the max length of the generated text')
+    return parser.parse_args()
+
+
+def main(args):
     torch.cuda.empty_cache()
-    while True:
-        choice = input("Please select a number from(\033[33m1\033[0m.Internlm-7b/\033[33m2\033[0m.Internlm-chat-7b/\033[33m3\033[0m.Internlm-chat-7b-8k): ")
-        if choice == '1':
-            print("You select \033[33mInternlm-7b\033[0m")
-            model_name = "Internlm-7b"
-            model_path = "internlm/internlm-7b"
-            max_value = 2048
-            break
-        elif choice == '2':
-            print("You select \033[33mInternlm-chat-7b\033[0m")
-            model_name = "Internlm-chat-7b"
-            model_path = "internlm/internlm-chat-7b"
-            max_value = 2048
-            break
-        elif choice == '3':
-            print("You select \033[33mInternlm-chat-7b-8k\033[0m")
-            model_name = "Internlm-chat-7b-8k"
-            model_path = "internlm/internlm-chat-7b-8k"
-            max_value = 8192
-            break
-        else:
-            print("Invalid selection")
     print("load model begin.")
-    model, tokenizer = load_model(model_path)
+    model, tokenizer = load_model(args.model_path)
     print("load model end.")
 
     user_avator = "doc/imgs/user.png"
     robot_avator = "doc/imgs/robot.png"
 
-    st.title(model_name)
+    st.title("InternLm-7b")
 
-    generation_config = prepare_generation_config(max_value)
+    generation_config = prepare_generation_config(args.max_value)
 
     # Initialize chat history
     if "messages" not in st.session_state:
@@ -264,5 +248,6 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)
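This commit replaces the interactive stdin menu with command-line flags: parse_args() supplies --model_path and --max_value, and main() now receives them instead of prompting for a model choice. Per the argparse description, the script's own arguments must come after a "--" separator so Streamlit does not consume them, e.g. `streamlit run web_demo.py -- --model_path internlm/internlm-chat-7b-8k --max_value 8192`. The removed menu had hard-coded the valid (model_path, max_value) pairs; a sketch restating that mapping as data, with names and values copied verbatim from the deleted if/elif branches:

# The presets the removed stdin menu encoded, restated as a lookup table.
# Keys and values are taken verbatim from the deleted branches.
MODEL_PRESETS = {
    "Internlm-7b":         ("internlm/internlm-7b", 2048),
    "Internlm-chat-7b":    ("internlm/internlm-chat-7b", 2048),
    "Internlm-chat-7b-8k": ("internlm/internlm-chat-7b-8k", 8192),
}

# Example: the flags that reproduce the old "choice 3" behavior.
model_path, max_value = MODEL_PRESETS["Internlm-chat-7b-8k"]
print(f"streamlit run web_demo.py -- --model_path {model_path} --max_value {max_value}")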
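Two caveats about the new flags, as observations on the diff rather than part of the commit: --max_value is declared without type=int, so a value passed on the command line reaches prepare_generation_config as a string while the default stays an int; and st.title("InternLm-7b") now hard-codes the page title even when --model_path selects a chat variant. A minimal hardened sketch, assuming prepare_generation_config expects an integer:

import argparse
import os

def parse_args():
    parser = argparse.ArgumentParser(description='Chat with the model, please add "--" between web_demo.py and Args in terminal')
    parser.add_argument('--model_path', default='internlm/internlm-7b', help='path of the InternLM model')
    # type=int keeps command-line values consistent with the integer default
    parser.add_argument('--max_value', type=int, default=2048, help='the max length of the generated text')
    return parser.parse_args()

args = parse_args()
# Derive the page title from the chosen model instead of hard-coding it.
title = os.path.basename(args.model_path)
print(title, args.max_value)  # e.g. "internlm-7b 2048"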