api.py: support concurrent serving.

pull/1018/head
Yu Xin 2023-05-14 23:19:08 +08:00
parent 2c25e52421
commit 9cc1bd5136
1 changed file with 16 additions and 3 deletions

api.py

@@ -2,6 +2,8 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+import threading
+import asyncio
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -17,9 +19,12 @@ def torch_gc():
 
 app = FastAPI()
 
+import concurrent
+from functools import partial
+pool = concurrent.futures.ThreadPoolExecutor(10)
 
 @app.post("/")
-async def create_item(request: Request):
+async def _create_item(request: Request):
     global model, tokenizer
     json_post_raw = await request.json()
     json_post = json.dumps(json_post_raw)
@@ -29,12 +34,13 @@ async def create_item(request: Request):
     max_length = json_post_list.get('max_length')
     top_p = json_post_list.get('top_p')
     temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
+    loop = asyncio.get_event_loop()
+    response, history = await loop.run_in_executor(pool, partial(model.chat, tokenizer,
                                    prompt,
                                    history=history,
                                    max_length=max_length if max_length else 2048,
                                    top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+                                   temperature=temperature if temperature else 0.95))
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
@@ -48,9 +54,16 @@ async def create_item(request: Request):
     torch_gc()
     return answer
 
+
+async def create_item(request: Request):
+    loop = asyncio.get_event_loop()
+    with concurrent.futures.ThreadPoolExecutor() as pool:
+        result = await loop.run_in_executor(pool, _create_item, request)
+        print(result)
+        return result
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
     model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)