pull/1340/merge
aleimu 2024-06-28 19:41:47 +08:00 committed by GitHub
commit f0f9113fa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 23 additions and 18 deletions

5
api.py
View File

@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime import uvicorn, json, datetime
import torch import torch
import asyncio
DEVICE = "cuda" DEVICE = "cuda"
DEVICE_ID = "0" DEVICE_ID = "0"
@@ -29,6 +30,8 @@ async def create_item(request: Request):
max_length = json_post_list.get('max_length') max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p') top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature') temperature = json_post_list.get('temperature')
def _sync_chat(history):
response, history = model.chat(tokenizer, response, history = model.chat(tokenizer,
prompt, prompt,
history=history, history=history,
@@ -48,6 +51,8 @@ async def create_item(request: Request):
torch_gc() torch_gc()
return answer return answer
return await asyncio.to_thread(_sync_chat, history=history)
if __name__ == '__main__': if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)