mirror of https://github.com/InternLM/InternLM
Update web_demo.py
parent fdb6a59c6a
commit e66bd817c1
web_demo.py
@@ -150,15 +150,15 @@ class GenerationConfig:
 @st.cache_resource
-def load_model():
-    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
-    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
+def load_model(model_path):
+    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     return model, tokenizer


-def prepare_generation_config():
+def prepare_generation_config(max_value):
     with st.sidebar:
-        max_length = st.slider("Max Length", min_value=32, max_value=8192, value=2048)
+        max_length = st.slider("Max Length", min_value=32, max_value=max_value, value=2048)
         top_p = st.slider(
             'Top P', 0.0, 1.0, 0.8, step=0.01
         )
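
A note on the load_model() change: st.cache_resource keys its cache on the function's arguments, so once model_path is a parameter, each distinct path gets its own cached model/tokenizer pair, and Streamlit reruns with an unchanged path reuse the already-loaded weights instead of reloading them. A minimal sketch of that behavior (fake_load and the sleep are illustrative stand-ins, not part of the demo):

    import time
    import streamlit as st

    @st.cache_resource
    def fake_load(model_path):
        # Body runs once per distinct model_path; later calls with the
        # same path return the cached object without re-executing.
        time.sleep(2)  # stand-in for the expensive from_pretrained() call
        return {"path": model_path}

    a = fake_load("internlm/internlm-chat-7b")     # slow: executes the body
    b = fake_load("internlm/internlm-chat-7b")     # fast: served from the cache
    c = fake_load("internlm/internlm-chat-7b-8k")  # slow again: new cache key
    assert a is b  # cache_resource hands back the same object, not a copy
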
@@ -199,16 +199,38 @@ def combine_history(prompt):
 def main():
     torch.cuda.empty_cache()
+    while True:
+        choice = input("Please select a number from(\033[33m1\033[0m.Internlm-7b/\033[33m2\033[0m.Internlm-chat-7b/\033[33m3\033[0m.Internlm-chat-7b-8k): ")
+        if choice == '1':
+            print("You select \033[33mInternlm-7b\033[0m")
+            model_name = "Internlm-7b"
+            model_path = "internlm/internlm-7b"
+            max_value = 2048
+            break
+        elif choice == '2':
+            print("You select \033[33mInternlm-chat-7b\033[0m")
+            model_name = "Internlm-chat-7b"
+            model_path = "internlm/internlm-chat-7b"
+            max_value = 2048
+            break
+        elif choice == '3':
+            print("You select \033[33mInternlm-chat-7b-8k\033[0m")
+            model_name = "Internlm-chat-7b-8k"
+            model_path = "internlm/internlm-chat-7b-8k"
+            max_value = 8192
+            break
+        else:
+            print("Invalid selection")
     print("load model begin.")
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(model_path)
     print("load model end.")

     user_avator = "doc/imgs/user.png"
     robot_avator = "doc/imgs/robot.png"

-    st.title("InternLM-Chat-7B")
+    st.title(model_name)

-    generation_config = prepare_generation_config()
+    generation_config = prepare_generation_config(max_value)

     # Initialize chat history
     if "messages" not in st.session_state:
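
The block added to main() runs in the terminal, before any Streamlit widget renders, and maps the typed choice to a (model_name, model_path, max_value) triple. For comparison, the same mapping can be written table-driven; this is a hypothetical equivalent sketch, not what the commit does:

    # Hypothetical table-driven equivalent of the if/elif chain above.
    MODELS = {
        '1': ("Internlm-7b", "internlm/internlm-7b", 2048),
        '2': ("Internlm-chat-7b", "internlm/internlm-chat-7b", 2048),
        '3': ("Internlm-chat-7b-8k", "internlm/internlm-chat-7b-8k", 8192),
    }

    def select_model():
        while True:
            choice = input("Please select a number from(1.Internlm-7b/2.Internlm-chat-7b/3.Internlm-chat-7b-8k): ")
            if choice in MODELS:
                name, path, max_value = MODELS[choice]
                print(f"You select {name}")
                return name, path, max_value
            print("Invalid selection")
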
@@ -244,10 +266,3 @@ def main():

 if __name__ == "__main__":
     main()
-
-
-
-
-
-
-
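
For context on how this runs: the demo is launched through the Streamlit CLI, so with this patch the model prompt appears in the terminal that issued the command, and the chat UI is then served at the local URL Streamlit prints (8501 is Streamlit's default port, assumed here):

    streamlit run web_demo.py
    # terminal: Please select a number from(1.Internlm-7b/2.Internlm-chat-7b/3.Internlm-chat-7b-8k): 2
    # browser:  http://localhost:8501
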