@@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     current_length = 0
@@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'
 
 
 if __name__ == "__main__":
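
The practical effect of this patch: sse_starlette's EventSourceResponse applies the SSE wire framing (the "data: " prefix and blank-line terminator) to each string the generator yields, so predict() no longer hand-formats "data: {}\n\n", and the final '[DONE]' sentinel goes out as a well-formed "data: [DONE]" event. Below is a minimal client sketch for consuming the resulting stream; the host/port, the OpenAI-style /v1/chat/completions route, and the model name "chatglm2-6b" are illustrative assumptions, not taken from this diff.

# Minimal client sketch for the streaming endpoint patched above.
# Assumptions (not from the diff): server at http://localhost:8000,
# an OpenAI-style /v1/chat/completions route, model name "chatglm2-6b".
import json
import requests

def stream_chat(prompt: str) -> None:
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "chatglm2-6b",
            "messages": [{"role": "user", "content": prompt}],
            "stream": True,
        },
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        # EventSourceResponse frames every yielded string as "data: <payload>".
        if not line or not line.startswith("data:"):
            continue  # skip blank event separators and SSE comments/pings
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":  # terminal sentinel yielded by predict()
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        # With exclude_unset=True the first chunk carries only the role,
        # so "content" may be absent.
        print(delta.get("content", ""), end="", flush=True)

if __name__ == "__main__":
    stream_chat("Hello!")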