Define the input and output structures with pydantic.

pull/1018/head
Yu Xin 2023-05-15 13:57:02 +08:00
parent 9cc1bd5136
commit faf8353d9f
1 changed file with 43 additions and 41 deletions

api.py

@@ -1,14 +1,30 @@
 from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
+from pydantic import BaseModel
+import uvicorn, datetime
 import torch
-import threading
 import asyncio
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
+EXECUTOR_POOL_SIZE = 10
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
 
+class Params(BaseModel):
+    prompt: str = 'hello'
+    history: list[list[str]] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
+
+
+class Answer(BaseModel):
+    status: int = 200
+    time: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    response: str
+    history: list[list[str]] = []
+
+
 def torch_gc():
     if torch.cuda.is_available():
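One pitfall worth noting in the models above: the default for Answer.time is evaluated once, when the class body runs at import, so every response would carry the server's start-up timestamp rather than the time of the request. The usual pydantic way to get a fresh value per instance is a default_factory; a minimal sketch of that alternative (standard pydantic, not part of this commit):

# Sketch: per-response timestamp via Field(default_factory=...),
# instead of the import-time default used in the diff above.
import datetime
from pydantic import BaseModel, Field

class Answer(BaseModel):
    status: int = 200
    time: str = Field(
        default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    )
    response: str
    history: list[list[str]] = []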
@@ -21,45 +37,31 @@ app = FastAPI()
 import concurrent
 from functools import partial
-pool = concurrent.futures.ThreadPoolExecutor(10)
+pool = concurrent.futures.ThreadPoolExecutor(EXECUTOR_POOL_SIZE)
 
 
-@app.post("/")
-async def _create_item(request: Request):
+@app.post("/chat")
+async def create_chat(params: Params) -> Answer:
     global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    loop = asyncio.get_event_loop()
-    response, history = await loop.run_in_executor(pool,partial(model.chat,tokenizer,
-                                                                prompt,
-                                                                history=history,
-                                                                max_length=max_length if max_length else 2048,
-                                                                top_p=top_p if top_p else 0.7,
-                                                                temperature=temperature if temperature else 0.95))
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
+    if EXECUTOR_POOL_SIZE != 0:
+        loop = asyncio.get_event_loop()
+        response, history = await loop.run_in_executor(pool, partial(model.chat,
+                                                                     tokenizer,
+                                                                     params.prompt,
+                                                                     history=params.history,
+                                                                     max_length=params.max_length,
+                                                                     top_p=params.top_p,
+                                                                     temperature=params.temperature))
+    else:
+        response, history = model.chat(tokenizer,
+                                       params.prompt,
+                                       history=params.history,
+                                       max_length=params.max_length,
+                                       top_p=params.top_p,
+                                       temperature=params.temperature)
+    answer_ok = Answer(response=response, history=history)
+    # print(answer_ok.json())
     torch_gc()
-    return answer
-
-
-async def create_item(request: Request):
-    loop = asyncio.get_event_loop()
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        result = await loop.run_in_executor(pool, _create_item, request)
-        print(result)
-        return result
+    return answer_ok
 
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
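With the endpoint moved from "/" to "/chat" and the body validated by Params, a client no longer has to send every field: omitted ones take the defaults declared on the model, and a malformed body is rejected by FastAPI with a 422 instead of reaching model.chat. A minimal client sketch; the host and port are an assumption (the uvicorn.run settings are not part of this diff), so point it at wherever this api.py is served:

# Hypothetical client for the reworked endpoint; adjust the URL
# to match your uvicorn host/port.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/chat",   # endpoint renamed from "/" to "/chat"
    json={
        "prompt": "Hello",
        "history": [],              # pairs of [user_turn, model_reply]
        # max_length / top_p / temperature may be omitted:
        # pydantic fills in the defaults declared on Params.
    },
)
resp.raise_for_status()             # schema errors surface as HTTP 422
print(resp.json())                  # {"status": 200, "time": ..., "response": ..., "history": [...]}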