fix: openai_api 的 stream api,服务端全部生成文本后客户端才一次性收到

pull/152/head
mougua 1 year ago
parent 53f0106817
commit fcd2d7f4bb

@@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse

 @asynccontextmanager
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")

     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     current_length = 0
@@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'

 if __name__ == "__main__":

@@ -5,4 +5,5 @@ torch>=2.0
 gradio
 mdtex2html
 sentencepiece
 accelerate
+sse-starlette
Loading…
Cancel
Save