Update api.py

使用更先进的方式加载模型
pull/129/head
zxgov 2023-06-30 10:25:01 +08:00 committed by GitHub
parent 3e3bd516d7
commit 732eab22c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 1 deletion

2
api.py
View File

@@ -51,7 +51,7 @@ async def create_item(request: Request):
if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
+model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')#.cuda()
# 多显卡支持使用下面三行代替上面两行将num_gpus改为你实际的显卡数量
# model_path = "THUDM/chatglm2-6b"
# tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)