diff --git a/api.py b/api.py
index 693c70a..ae30fd5 100644
--- a/api.py
+++ b/api.py
@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+from sse_starlette.sse import EventSourceResponse
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -18,17 +19,23 @@ def torch_gc():
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
+def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
+    for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        yield json.dumps({
+            'response': response,
+            'history': history,
+            'status': 200,
+            'time': time
+        })
+    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+    print(log)
+    return torch_gc()
+
+
+def predict(tokenizer, prompt, history, max_length, top_p, temperature):
     response, history = model.chat(tokenizer,
                                    prompt,
                                    history=history,
@@ -48,6 +55,25 @@ async def create_item(request: Request):
     torch_gc()
     return answer
 
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    stream = json_post_list.get('stream')
+    if stream:
+        res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
+        return EventSourceResponse(res)
+    else:
+        answer = predict(tokenizer, prompt, history, max_length, top_p, temperature)
+        return answer
+
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
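
For reference, here is a minimal client sketch for exercising the patched endpoint. It posts the same JSON body the handler parses and, when "stream" is set, reads the Server-Sent Events emitted by EventSourceResponse (each "data:" line carries the JSON string yielded by predict_stream). The host/port (127.0.0.1:8000, the default in ChatGLM-6B's api.py) and the sample generation parameters are assumptions, not taken from this diff; adjust them to your deployment.

# Client-side sketch, not part of the patch above.
import json
import requests

payload = {
    "prompt": "Hello",      # sample values; tune for your use case
    "history": [],
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
    "stream": True,         # take the SSE path added by this patch
}

# Assumption: the server is listening on 127.0.0.1:8000.
with requests.post("http://127.0.0.1:8000/", json=payload, stream=True) as resp:
    # sse_starlette frames each yielded string as an SSE event: "data: <payload>\n\n"
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk["response"])

For a non-streaming call, omit "stream" (or send a falsy value) and read resp.json() instead; the handler then returns predict()'s answer dict as a regular JSON response.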