Add SSE streaming output

To request streaming output, set the stream parameter to 1 when sending a request.
The request format is as follows:
curl -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": [], "stream":1}'

The response is returned in the SSE streaming format (shown pretty-printed below; on the wire each event is a single data: line):
data:{
  "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
  "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
  "status":200,
  "time":"2023-03-23 21:38:40"
}
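
For reference, a minimal client-side sketch for reading the stream (not part of this change; it assumes the server above is running on 127.0.0.1:8000 and uses the third-party requests package):

import json
import requests

payload = {"prompt": "你好", "history": [], "stream": 1}
# stream=True keeps the HTTP connection open so events can be read as they arrive
with requests.post("http://127.0.0.1:8000", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # each SSE event arrives as a "data: ..." line; blank lines separate events
        if line and line.startswith("data:"):
            event = json.loads(line[len("data:"):].strip())
            print(event["response"])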
api.py (1 file changed, 37 additions, 11 deletions):

@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+from sse_starlette.sse import EventSourceResponse
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -18,17 +19,23 @@ def torch_gc():
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
+def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
+    for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        yield json.dumps({
+            'response': response,
+            'history': history,
+            'status': 200,
+            'time': time
+        })
+    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+    print(log)
+    return torch_gc()
+
+
+def predict(tokenizer, prompt, history, max_length, top_p, temperature):
     response, history = model.chat(tokenizer,
                                    prompt,
                                    history=history,
@@ -48,6 +55,25 @@ async def create_item(request: Request):
     torch_gc()
     return answer
 
 
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    stream = json_post_list.get('stream')
+    if stream:
+        res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
+        return EventSourceResponse(res)
+    else:
+        answer = predict(tokenizer, prompt, history, max_length, top_p, temperature)
+        return answer
+
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
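
Note: the new import means the server now depends on the sse-starlette package (pip install sse-starlette), which provides EventSourceResponse; api.py will not start without it. A quick way to watch the stream from the command line is to rerun the curl request above with -N so curl does not buffer the output:

curl -N -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": [], "stream":1}'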