Update api.py

Make the synchronous method asynchronous to improve the concurrency of API calls and keep requests from blocking one another at the service layer
pull/1340/head
aleimu 2023-07-18 17:50:19 +08:00 committed by GitHub
parent db237cc258
commit 044cf323a7
1 changed file with 27 additions and 18 deletions

api.py

@@ -1,7 +1,9 @@
+import torch
+import asyncio
+import concurrent.futures
 from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
-import torch
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -29,6 +31,8 @@ async def create_item(request: Request):
     max_length = json_post_list.get('max_length')
     top_p = json_post_list.get('top_p')
     temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
+
+    def _sync_chat(history):
+        response, history = model.chat(tokenizer,
+                                       prompt,
+                                       history=history,
@@ -48,6 +52,11 @@ async def create_item(request: Request):
-    torch_gc()
-    return answer
+        torch_gc()
+        return answer
+
+    loop = asyncio.get_event_loop()
+    executor = concurrent.futures.ThreadPoolExecutor()
+    answer = await loop.run_in_executor(executor, _sync_chat, history)
+    return answer
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
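For reference, a minimal, self-contained sketch of the pattern this commit applies: the blocking chat call is handed to a thread pool via run_in_executor, so the event loop stays free to serve other requests in the meantime. The blocking_chat placeholder (standing in for model.chat), the shared module-level executor, and the max_workers value are illustrative assumptions, not code from this repository.

import asyncio
import concurrent.futures
import time

from fastapi import FastAPI

app = FastAPI()

# One shared executor created at import time; the commit instead builds a
# new ThreadPoolExecutor inside the request handler, which also works but
# pays the pool-construction cost on every request.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

def blocking_chat(prompt: str) -> str:
    # Placeholder for the synchronous model.chat call; the sleep simulates
    # a long-running inference.
    time.sleep(2)
    return f"echo: {prompt}"

@app.post("/")
async def create_item(prompt: str):
    # get_running_loop() is the currently recommended way to reach the loop
    # from inside a coroutine; the commit uses asyncio.get_event_loop().
    loop = asyncio.get_running_loop()
    # The await suspends this handler while a worker thread runs the
    # blocking function, so concurrent requests no longer serialize behind
    # one another at the service layer.
    answer = await loop.run_in_executor(executor, blocking_chat, prompt)
    return {"response": answer}

Passing None as the first argument to run_in_executor would use the loop's default thread pool instead, which avoids managing an executor by hand.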