Update cli_demo.py

使用更先进的方式加载模型
pull/129/head
zxgov 2023-06-30 10:26:12 +08:00 committed by GitHub
parent 732eab22c8
commit e84a5f3c14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModel
import readline import readline
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')#.cuda()
# 多显卡支持使用下面两行代替上面一行将num_gpus改为你实际的显卡数量 # 多显卡支持使用下面两行代替上面一行将num_gpus改为你实际的显卡数量
# from utils import load_model_on_gpus # from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2) # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)