From ae180ba8b8e87bf6606d5044713a7d0bf7825acb Mon Sep 17 00:00:00 2001 From: zkz098 Date: Sun, 20 Aug 2023 22:41:45 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=BF=94=E5=9B=9Ejson=E7=BC=BA=E5=B0=91?= =?UTF-8?q?usage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- openai_api.py | 24 +++++++++++++++++------- requirements.txt | 3 ++- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/openai_api.py b/openai_api.py index 7225562..4cf5fda 100644 --- a/openai_api.py +++ b/openai_api.py @@ -5,6 +5,8 @@ import time + +import tiktoken import torch import uvicorn from pydantic import BaseModel, Field @@ -17,7 +19,7 @@ from sse_starlette.sse import ServerSentEvent, EventSourceResponse @asynccontextmanager -async def lifespan(app: FastAPI): # collects GPU memory +async def lifespan(app: FastAPI): # collects GPU memory yield if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -34,6 +36,7 @@ app.add_middleware( allow_headers=["*"], ) + class ModelCard(BaseModel): id: str object: str = "model" @@ -85,6 +88,7 @@ class ChatCompletionResponse(BaseModel): object: Literal["chat.completion", "chat.completion.chunk"] choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] created: Optional[int] = Field(default_factory=lambda: int(time.time())) + usage: dict @app.get("/v1/models", response_model=ModelList) @@ -109,8 +113,8 @@ async def create_chat_completion(request: ChatCompletionRequest): history = [] if len(prev_messages) % 2 == 0: for i in range(0, len(prev_messages), 2): - if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": - history.append([prev_messages[i].content, prev_messages[i+1].content]) + if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant": + history.append([prev_messages[i].content, prev_messages[i + 1].content]) if request.stream: generate = predict(query, history, request.model) @@ -122,8 +126,16 @@ async def create_chat_completion(request: ChatCompletionRequest): message=ChatMessage(role="assistant", content=response), finish_reason="stop" ) - - return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + pt = len(encoding.encode(query)) + rt = len(encoding.encode(response)) + usage_data = { + "prompt_tokens": pt, + "completion_tokens": rt, + "total_tokens": pt + rt + } + return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", + usage=usage_data) async def predict(query: str, history: List[List[str]], model_id: str): @@ -154,7 +166,6 @@ async def predict(query: str, history: List[List[str]], model_id: str): chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) - choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(), @@ -165,7 +176,6 @@ async def predict(query: str, history: List[List[str]], model_id: str): yield '[DONE]' - if __name__ == "__main__": tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() diff --git a/requirements.txt b/requirements.txt index 265b8eb..1d43416 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ mdtex2html sentencepiece accelerate sse-starlette -streamlit>=1.24.0 \ No newline at end of file +streamlit>=1.24.0 +tiktoken \ No newline at end of file