feat: rewrite the API

pull/1066/head
DealiAxy 2023-05-19 17:33:40 +08:00
parent bc6695b7f2
commit 770676fdd5
1 changed file with 44 additions and 34 deletions

api.py

@@ -1,40 +1,49 @@
+import json
+import datetime
+import torch
+import uvicorn
+from typing import List
 from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
-import torch
+from pydantic import BaseModel
+from utils import load_model_on_gpus
 
-DEVICE = "cuda"
-DEVICE_ID = "0"
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+devices_list = [
+    'cuda:0',
+    'cuda:1'
+]
 
 
-def torch_gc():
+def _torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        for item in devices_list:
+            with torch.cuda.device(item):
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+
+
+class Question(BaseModel):
+    prompt: str
+    history: List[str] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
 
 
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+@app.post('/chat/')
+async def chat(question: Question):
+    response, history = model.chat(
+        tokenizer,
+        question.prompt,
+        history=question.history,
+        max_length=question.max_length,
+        top_p=question.top_p,
+        temperature=question.temperature
+    )
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
@@ -43,14 +52,15 @@ async def create_item(request: Request):
         "status": 200,
         "time": time
     }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
+    _torch_gc()
    return answer
 
 
-if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b", trust_remote_code=True
+    )
+    model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
     model.eval()
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="127.0.0.1", port=11001, workers=1)
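
For reference, a minimal client sketch for exercising the rewritten endpoint. It assumes the server from this commit is running locally on port 11001 as configured above; the prompt text and the requests dependency are illustrative and not part of the commit.

# Hypothetical client for the new /chat/ endpoint added in this commit.
# Field names and defaults mirror the Question model defined in api.py.
import requests

payload = {
    "prompt": "Hello, ChatGLM-6B",
    "history": [],
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}

resp = requests.post("http://127.0.0.1:11001/chat/", json=payload)
resp.raise_for_status()
data = resp.json()
print(data["response"])  # model reply
print(data["history"])   # updated conversation history returned by model.chat

Because the handler now accepts a Pydantic Question model instead of hand-parsing the raw JSON body, FastAPI rejects missing or mistyped fields with a 422 validation error before model.chat is ever called.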