Update web_demo2.py

使用更先进的方式加载模型
pull/129/head
zxgov 2023-06-30 10:27:48 +08:00 committed by GitHub
parent 07d463034b
commit bcb8514820
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -13,7 +13,7 @@ st.set_page_config(
@st.cache_resource
def get_model():
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')#.cuda()
# 多显卡支持使用下面两行代替上面一行将num_gpus改为你实际的显卡数量
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)