Add SSE streaming output

When sending a request, set the stream parameter to 1 to request streaming output.
The request format is as follows:
curl -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": [], "stream":1}'

The reply is returned in the SSE streaming output format:
data:{
  "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
  "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
  "status":200,
  "time":"2023-03-23 21:38:40"
}
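For reference, a minimal Python client sketch for consuming this stream (not part of the change; it assumes the requests library is installed and the server above is running on 127.0.0.1:8000):

# Minimal SSE client sketch (illustration only): POST with "stream": 1 and
# print the "response" field of each data: frame as it arrives.
import json
import requests

payload = {"prompt": "你好", "history": [], "stream": 1}
with requests.post("http://127.0.0.1:8000", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue                      # skip blank separators and non-data lines
        data = line[len("data:"):].strip()
        if not data.startswith("{"):
            continue                      # ignore keep-alive payloads, if any
        chunk = json.loads(data)
        print(chunk["response"])

With ChatGLM-6B's stream_chat, each event typically carries the response generated so far, so the printed text grows with every frame.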
pull/679/head
littlestone0806 committed 2023-04-18 16:16:43 +08:00 (committed by GitHub)
commit 35f45dcf1b (parent c6790a09f0)
1 changed file with 37 additions and 11 deletions

api.py

@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+from sse_starlette.sse import EventSourceResponse
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -18,17 +19,23 @@ def torch_gc():
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
+def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
+    for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        yield json.dumps({
+            'response': response,
+            'history': history,
+            'status': 200,
+            'time': time
+        })
+    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+    print(log)
+    return torch_gc()
+
+
+def predict(tokenizer, prompt, history, max_length, top_p, temperature):
     response, history = model.chat(tokenizer,
                                    prompt,
                                    history=history,
@@ -48,6 +55,25 @@ async def create_item(request: Request):
     torch_gc()
     return answer
 
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    stream = json_post_list.get('stream')
+    if stream:
+        res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
+        return EventSourceResponse(res)
+    else:
+        answer = predict(tokenizer, prompt, history, max_length, top_p, temperature)
+        return answer
+
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
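
Note that the new import adds a dependency on the sse-starlette package (pip install sse-starlette), which the original api.py does not need. As a rough mental model of what EventSourceResponse does with the generator, here is a small self-contained sketch (hypothetical demo names, no model involved): each string yielded by the generator is written to the client as one SSE data: frame.

# Stand-alone sketch of the SSE plumbing used above (hypothetical demo app,
# not the PR's code). Run with, e.g.: uvicorn sse_demo:demo --port 8001
import json
import time

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

demo = FastAPI()

def fake_stream():
    # Stands in for model.stream_chat(): yields a few JSON strings.
    for text in ["你", "你好", "你好👋!"]:
        yield json.dumps({"response": text, "status": 200})
        time.sleep(0.2)   # simulate generation latency

@demo.get("/demo")
def stream_demo():
    # Each yielded string becomes one "data: ..." event on the wire.
    return EventSourceResponse(fake_stream())

Requesting /demo (for example with curl -N) should then show the data: frames arriving one at a time.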