Use an async context manager to speed up startup

pull/547/head
Whitroom 2023-04-12 14:48:49 +08:00
parent 24dada9d5a
commit be7e14ce45
1 changed file with 28 additions and 34 deletions
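For context: FastAPI's `lifespan` parameter accepts an async context manager; everything before the `yield` runs once before the server starts accepting requests, and everything after it runs once at shutdown. A minimal sketch of the pattern this commit adopts, with placeholder resource names (not this repo's code):

from contextlib import asynccontextmanager

from fastapi import FastAPI

resources = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load expensive state (models, connections) exactly once.
    resources['model'] = object()  # placeholder for a real model load
    yield  # the application serves requests here
    # Shutdown: drop every reference so memory can be reclaimed.
    resources.clear()

app = FastAPI(lifespan=lifespan)

One nuance: `del model` on a loop variable, as in the diff's shutdown path below, only unbinds the local name; `models.clear()` is what actually drops the dictionary's references before `torch_gc()` runs.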

api.py

@@ -1,9 +1,9 @@
 import datetime
-import json
+from contextlib import asynccontextmanager
 
 import torch
 import uvicorn
 
-from fastapi import FastAPI, Request
+from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoModel, AutoTokenizer
@@ -11,6 +11,21 @@ DEVICE = "cuda"
 DEVICE_ID = "0"
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
+models = {}
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    models['chat'] = AutoModel.from_pretrained(
+        "THUDM/models",
+        trust_remote_code=True).half().cuda()
+    models['chat'].eval()
+    models['tokenizer'] = AutoTokenizer.from_pretrained(
+        "THUDM/models",
+        trust_remote_code=True)
+    yield
+    for model in models.values():
+        del model
+    torch_gc()
 
 def torch_gc():
     if torch.cuda.is_available():
@@ -18,8 +33,7 @@ def torch_gc():
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
-
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 
 class Item(BaseModel):
     prompt: str
@@ -35,38 +49,18 @@ class Answer(BaseModel):
     time: str
 
 @app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+async def create_item(item: Item):
+    response, history = models['chat'].chat(
+        models['tokenizer'],
+        item.prompt,
+        history=item.history,
+        max_length=item.max_length,
+        top_p=item.top_p,
+        temperature=item.temperature)
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-    return answer
-
+    print(f"[{time}] prompt: '{item.prompt}', response: '{response}'")
+    return Answer(response=response, history=history, status=200, time=time)
 
 if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
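With the handler now typed against `Item`, requests are plain JSON posts. A minimal client sketch; the field values are illustrative, and because the new handler forwards `item.max_length`, `item.top_p`, and `item.temperature` directly, this request supplies them explicitly rather than relying on the old in-handler fallbacks (2048 / 0.7 / 0.95):

import requests

# Endpoint taken from the diff: POST / on port 8000.
resp = requests.post("http://127.0.0.1:8000/", json={
    "prompt": "你好",
    "history": [],
    "max_length": 2048,   # old default, now passed explicitly
    "top_p": 0.7,
    "temperature": 0.95,
})
print(resp.json())  # {"response": ..., "history": [...], "status": 200, "time": "..."}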