Update api.py

同步方法异步化,增加api调用的并发性能,防止请求之间在服务层相互阻塞
pull/1340/head
aleimu 2023-07-18 17:50:19 +08:00 committed by GitHub
parent db237cc258
commit 044cf323a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 27 additions and 18 deletions

45
api.py
View File

@ -1,7 +1,9 @@
import torch
import asyncio
import concurrent.futures
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch
DEVICE = "cuda"
DEVICE_ID = "0"
@ -29,23 +31,30 @@ async def create_item(request: Request):
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
def _sync_chat(history):
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
return answer
loop = asyncio.get_event_loop()
executor = concurrent.futures.ThreadPoolExecutor()
answer = await loop.run_in_executor(executor, _sync_chat, history)
return answer