Use pydantic to define the input and output structures.

pull/1018/head
Yu Xin 2023-05-15 13:57:02 +08:00
parent 9cc1bd5136
commit faf8353d9f
1 changed file with 43 additions and 41 deletions

api.py

@@ -1,14 +1,30 @@
-from fastapi import FastAPI, Request
-from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
-import torch
 import threading
 import asyncio
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer, AutoModel
+from pydantic import BaseModel
+import uvicorn, datetime
+import torch
 DEVICE = "cuda"
 DEVICE_ID = "0"
+EXECUTOR_POOL_SIZE = 10
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+class Params(BaseModel):
+    prompt: str = 'hello'
+    history: list[list[str]] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
+class Answer(BaseModel):
+    status: int = 200
+    time: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    response: str
+    history: list[list[str]] = []
 def torch_gc():
     if torch.cuda.is_available():
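
Aside (not part of the commit): a minimal, hypothetical sketch of what the new Params model buys the API. Missing request fields pick up the defaults declared above, and a value that cannot be coerced to the declared type is rejected by pydantic before model.chat is ever called. Assumes Python 3.9+ for the list[list[str]] annotation, as in the diff.

from pydantic import BaseModel, ValidationError

class Params(BaseModel):
    prompt: str = 'hello'
    history: list[list[str]] = []
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95

p = Params(prompt="Hi")
print(p.max_length, p.top_p, p.temperature)   # 2048 0.7 0.95 -- unspecified fields fall back to defaults

try:
    Params(prompt="Hi", max_length="lots")    # "lots" is not coercible to int
except ValidationError as err:
    print(err)                                # pydantic reports the offending field and reason
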
@@ -21,45 +37,31 @@ app = FastAPI()
 import concurrent
 from functools import partial
-pool = concurrent.futures.ThreadPoolExecutor(10)
+pool = concurrent.futures.ThreadPoolExecutor(EXECUTOR_POOL_SIZE)
-@app.post("/")
-async def _create_item(request: Request):
+@app.post("/chat")
+async def create_chat(params: Params) -> Answer:
     global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    loop = asyncio.get_event_loop()
-    response, history = await loop.run_in_executor(pool, partial(model.chat, tokenizer,
-                                                                 prompt,
-                                                                 history=history,
-                                                                 max_length=max_length if max_length else 2048,
-                                                                 top_p=top_p if top_p else 0.7,
-                                                                 temperature=temperature if temperature else 0.95))
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
+    if EXECUTOR_POOL_SIZE != 0:
+        loop = asyncio.get_event_loop()
+        response, history = await loop.run_in_executor(pool, partial(model.chat,
+                                                                     tokenizer,
+                                                                     params.prompt,
+                                                                     history=params.history,
+                                                                     max_length=params.max_length,
+                                                                     top_p=params.top_p,
+                                                                     temperature=params.temperature))
+    else:
+        response, history = model.chat(tokenizer,
+                                       params.prompt,
+                                       history=params.history,
+                                       max_length=params.max_length,
+                                       top_p=params.top_p,
+                                       temperature=params.temperature)
+    answer_ok = Answer(response=response, history=history)
+    # print(answer_ok.json())
     torch_gc()
-    return answer
-async def create_item(request: Request):
-    loop = asyncio.get_event_loop()
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        result = await loop.run_in_executor(pool, _create_item, request)
-    print(result)
-    return result
+    return answer_ok
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
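
Usage note (hypothetical, not part of the commit): the endpoint moves from "/" to "/chat" and now expects a JSON body shaped like Params, with the reply serialized from the Answer model; per the diff, a non-zero EXECUTOR_POOL_SIZE hands the blocking model.chat call to the thread pool via run_in_executor, otherwise it runs inline in the handler. Assuming the unchanged tail of api.py still serves on port 8000, a client call might look like this:

import requests

payload = {
    "prompt": "你好",
    "history": [],
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}
# Omitted keys fall back to the Params defaults on the server side.
resp = requests.post("http://127.0.0.1:8000/chat", json=payload)
# Expected fields, per the Answer model: status, time, response, history
print(resp.json())
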