From 770676fdd51a4c3f25cfb0181019a42e69d6bb61 Mon Sep 17 00:00:00 2001
From: DealiAxy
Date: Fri, 19 May 2023 17:33:40 +0800
Subject: [PATCH] ☀ feat: rewrite the API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api.py | 78 +++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 44 insertions(+), 34 deletions(-)

diff --git a/api.py b/api.py
index 693c70a..4ecaa27 100644
--- a/api.py
+++ b/api.py
@@ -1,40 +1,49 @@
+import json
+import datetime
+import torch
+import uvicorn
+from typing import List
 from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
-import torch
-
-DEVICE = "cuda"
-DEVICE_ID = "0"
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+from pydantic import BaseModel
+from utils import load_model_on_gpus
 
 
-def torch_gc():
+devices_list = [
+    'cuda:0',
+    'cuda:1'
+]
+
+
+def _torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        for item in devices_list:
+            with torch.cuda.device(item):
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+
+
+class Question(BaseModel):
+    prompt: str
+    history: List[str] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
 
 
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+@app.post('/chat/')
+async def chat(question: Question):
+    response, history = model.chat(
+        tokenizer,
+        question.prompt,
+        history=question.history,
+        max_length=question.max_length,
+        top_p=question.top_p,
+        temperature=question.temperature
+    )
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
@@ -43,14 +52,15 @@ async def create_item(request: Request):
         "status": 200,
         "time": time
     }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
+    _torch_gc()
     return answer
 
 
-if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b", trust_remote_code=True
+    )
+    model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
     model.eval()
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="127.0.0.1", port=11001, workers=1)
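
Usage sketch (not part of the patch): the hunks above replace the raw-JSON handler with a typed /chat/ route driven by the Question model, so a client now posts a JSON body matching those fields. The snippet below is a minimal example under the patch's own defaults; the host and port (127.0.0.1:11001) come from the new uvicorn.run call, while the requests dependency and the prompt text are assumptions for illustration. Note that the patch types history as List[str], whereas ChatGLM's model.chat expects (query, response) pairs, so the example keeps history empty.

    # Minimal client sketch against the rewritten endpoint (assumes the patched
    # server is running locally and that `requests` is installed).
    import requests

    payload = {
        "prompt": "Hello",   # assumed prompt text, purely illustrative
        "history": [],       # left empty; see the List[str] caveat above
        "max_length": 2048,
        "top_p": 0.7,
        "temperature": 0.95,
    }

    # Endpoint and port taken from the patched api.py.
    resp = requests.post("http://127.0.0.1:11001/chat/", json=payload, timeout=300)
    resp.raise_for_status()

    data = resp.json()
    print(data["response"])  # model reply
    print(data["time"])      # timestamp added by the handler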