From 1fa2608dd15fb579f739edb3e6d58e03e7c797f5 Mon Sep 17 00:00:00 2001
From: fenglui
Date: Sun, 23 Jul 2023 00:53:16 +0800
Subject: [PATCH] add api_add_middleware.py with CORS support

---
 api_add_middleware.py | 71 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 api_add_middleware.py

diff --git a/api_add_middleware.py b/api_add_middleware.py
new file mode 100644
index 0000000..22a8cc8
--- /dev/null
+++ b/api_add_middleware.py
@@ -0,0 +1,71 @@
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from transformers import AutoTokenizer, AutoModel
+import uvicorn, datetime
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    # Release cached GPU memory after each request.
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
+app = FastAPI()
+# Allow cross-origin requests from any origin so that browser-based
+# clients can call the API directly.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post = await request.json()
+    prompt = json_post.get('prompt')
+    history = json_post.get('history')
+    max_length = json_post.get('max_length')
+    top_p = json_post.get('top_p')
+    temperature = json_post.get('temperature')
+    response, history = model.chat(tokenizer,
+                                   prompt,
+                                   history=history,
+                                   max_length=max_length if max_length else 2048,
+                                   top_p=top_p if top_p else 0.7,
+                                   temperature=temperature if temperature else 0.95)
+    now = datetime.datetime.now()
+    time = now.strftime("%Y-%m-%d %H:%M:%S")
+    answer = {
+        "response": response,
+        "history": history,
+        "status": 200,
+        "time": time
+    }
+    log = "[" + time + "] " + 'prompt:"' + prompt + '", response:"' + repr(response) + '"'
+    print(log)
+    torch_gc()
+    return answer
+
+
+if __name__ == '__main__':
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
+    # Multi-GPU support: use the following four lines instead of the two
+    # above, and set num_gpus to the number of GPUs you actually have.
+    # from utils import load_model_on_gpus
+    # model_path = "THUDM/chatglm2-6b"
+    # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    # model = load_model_on_gpus(model_path, num_gpus=2)
+    model.eval()
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
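
--
Usage note (illustrative, not part of the patch): CORSMiddleware only
affects browser-based callers; a plain HTTP client works with or without
it. A minimal client sketch, assuming the server above is running locally
on port 8000 and the third-party `requests` package is installed:

    import requests

    # POST a chat request; every field except "prompt" falls back to the
    # server-side defaults (max_length=2048, top_p=0.7, temperature=0.95).
    resp = requests.post(
        "http://localhost:8000/",
        json={"prompt": "Hello", "history": []},
    )
    print(resp.json()["response"])

One caveat on the configuration above: browsers reject a wildcard
Access-Control-Allow-Origin header on credentialed requests, so if you
rely on allow_credentials=True you may want to replace
allow_origins=["*"] with an explicit list of origins.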