Use an async context manager to speed up startup

pull/547/head
Whitroom 2023-04-12 14:48:49 +08:00
parent 24dada9d5a
commit be7e14ce45
1 changed file with 28 additions and 34 deletions
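
For reference, FastAPI's lifespan hook takes an async context manager: everything before the yield runs once before the server starts accepting requests, and everything after it runs at shutdown. A minimal sketch of the pattern this commit adopts (the resource names and placeholder load below are illustrative, not the exact objects in the diff):

from contextlib import asynccontextmanager
from fastapi import FastAPI

resources = {}  # shared objects loaded once per worker process

@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup: runs once before the first request is served
    resources["model"] = object()  # placeholder for an expensive load
    yield
    # shutdown: release whatever was loaded above
    resources.clear()

app = FastAPI(lifespan=lifespan)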

api.py

@@ -1,9 +1,9 @@
 import datetime
 import json
+from contextlib import asynccontextmanager
 import torch
 import uvicorn
-from fastapi import FastAPI, Request
+from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoModel, AutoTokenizer
@@ -11,6 +11,21 @@ DEVICE = "cuda"
 DEVICE_ID = "0"
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+models = {}
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    models['chat'] = AutoModel.from_pretrained(
+        "THUDM/models",
+        trust_remote_code=True).half().cuda()
+    models['chat'].eval()
+    models['tokenizer'] = AutoTokenizer.from_pretrained(
+        "THUDM/models",
+        trust_remote_code=True)
+    yield
+    for model in models.values():
+        del model
+    torch_gc()
 def torch_gc():
     if torch.cuda.is_available():
@@ -18,8 +33,7 @@ def torch_gc():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 class Item(BaseModel):
     prompt: str
@@ -35,38 +49,18 @@ class Answer(BaseModel):
     time: str
 @app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+async def create_item(item: Item):
+    response, history = models['chat'].chat(
+        models['tokenizer'],
+        item.prompt,
+        history=item.history,
+        max_length=item.max_length,
+        top_p=item.top_p,
+        temperature=item.temperature)
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-    return answer
+    print(f"[{time}] prompt: '{item.prompt}', response: '{response}'")
+    return Answer(response=response, history=history, status=200, time=time)
 if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-    model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
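
With the Pydantic models in place, the endpoint now expects a JSON body matching Item and returns an Answer. A sketch of a client call, assuming the server runs locally on port 8000 as configured above; whether any Item fields carry defaults is not visible in this diff, so all fields are sent explicitly, reusing the sampling values the old handler fell back to:

import requests  # any HTTP client works; requests is assumed here

payload = {
    "prompt": "Hello",
    "history": [],
    "max_length": 2048,   # defaults used by the old handler
    "top_p": 0.7,
    "temperature": 0.95,
}
resp = requests.post("http://127.0.0.1:8000/", json=payload)
# Expected shape: {"response": ..., "history": ..., "status": 200, "time": ...}
print(resp.json())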