fix: openai_api 的 stream api,服务端全部生成文本后客户端才一次性收到

pull/152/head
mougua 2023-07-02 12:38:44 +08:00
parent 53f0106817
commit fcd2d7f4bb
2 changed files with 10 additions and 6 deletions

View File

@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
@asynccontextmanager
@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if request.stream:
generate = predict(query, history, request.model)
return StreamingResponse(generate, media_type="text/event-stream")
return EventSourceResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice(
@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0
@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__":

View File

@ -5,4 +5,5 @@ torch>=2.0
gradio
mdtex2html
sentencepiece
accelerate
accelerate
sse-starlette