From 35f45dcf1bac861daa2384f75b50572af96cdde3 Mon Sep 17 00:00:00 2001
From: littlestone0806 <42195561+littlestone0806@users.noreply.github.com>
Date: Tue, 18 Apr 2023 16:16:43 +0800
Subject: [PATCH] Add SSE streaming output
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

To request streaming output, set the "stream" parameter to 1 in the request.
The request format is as follows:

curl -X POST "http://127.0.0.1:8000" \
    -H 'Content-Type: application/json' \
    -d '{"prompt": "你好", "history": [], "stream":1}'

The reply is returned in the SSE streaming output format:

data:{
    "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
    "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
    "status":200,
    "time":"2023-03-23 21:38:40"
}
---
 api.py | 48 +++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/api.py b/api.py
index 693c70a..ae30fd5 100644
--- a/api.py
+++ b/api.py
@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+from sse_starlette.sse import EventSourceResponse
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -18,17 +19,23 @@ def torch_gc():
 
 app = FastAPI()
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
+def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
+    for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        yield json.dumps({
+            'response': response,
+            'history': history,
+            'status': 200,
+            'time': time
+        })
+    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+    print(log)
+    return torch_gc()
+
+
+def predict(tokenizer, prompt, history, max_length, top_p, temperature):
     response, history = model.chat(tokenizer,
                                    prompt,
                                    history=history,
@@ -48,6 +55,25 @@ async def create_item(request: Request):
     torch_gc()
     return answer
 
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    stream = json_post_list.get('stream')
+    if stream:
+        res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
+        return EventSourceResponse(res)
+    else:
+        answer = predict(tokenizer, prompt, history, max_length, top_p, temperature)
+        return answer
+
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
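
For reference, below is a minimal client sketch (not part of the patch) showing how the SSE stream could be consumed. It assumes the patched api.py is serving on http://127.0.0.1:8000, that the requests library is installed, and that sse-starlette frames each yielded string as a "data: ..." line; the function name stream_chat_client is illustrative only.

import json
import requests

# Illustrative client only: reads the SSE stream produced by the patched
# api.py and yields each partial response as it arrives.
def stream_chat_client(prompt, history=None, url="http://127.0.0.1:8000"):
    payload = {"prompt": prompt, "history": history or [], "stream": 1}
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            # EventSourceResponse sends each yielded JSON string as "data: <json>"
            if line and line.startswith("data:"):
                event = json.loads(line[len("data:"):].strip())
                yield event["response"], event["history"]

if __name__ == "__main__":
    for partial_response, _history in stream_chat_client("你好"):
        print(partial_response)

Because stream_chat yields the accumulated text so far, each event carries the full response up to that point, and the last event holds the complete answer.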