mirror of https://github.com/THUDM/ChatGLM2-6B
fix: returned JSON is missing the usage field
parent 80602dcae1
commit ae180ba8b8
openai_api.py
@@ -5,6 +5,8 @@
 
 import time
+import tiktoken
 import torch
 import uvicorn
 from pydantic import BaseModel, Field
@@ -34,6 +36,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
 
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
@@ -85,6 +88,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    usage: dict
 
 
 @app.get("/v1/models", response_model=ModelList)
@@ -122,8 +126,16 @@ async def create_chat_completion(request: ChatCompletionRequest):
         message=ChatMessage(role="assistant", content=response),
         finish_reason="stop"
     )
 
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
+    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+    pt = len(encoding.encode(query))
+    rt = len(encoding.encode(response))
+    usage_data = {
+        "prompt_tokens": pt,
+        "completion_tokens": rt,
+        "total_tokens": pt + rt
+    }
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion",
+                                  usage=usage_data)
 
 
 async def predict(query: str, history: List[List[str]], model_id: str):
@@ -154,7 +166,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
-
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
@@ -165,7 +176,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield '[DONE]'
-
 
 
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
     model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
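The core of the change is the token accounting added to create_chat_completion: the prompt and the generated answer are measured with tiktoken and returned as an OpenAI-style usage block. Below is a minimal, self-contained sketch of that computation, assuming only that the tiktoken package is installed; count_usage is a hypothetical helper name, not part of openai_api.py, and the counts are approximate because gpt-3.5-turbo's cl100k_base encoding is not ChatGLM2-6B's own tokenizer.

# Minimal sketch of the usage computation this commit adds;
# "count_usage" is a hypothetical helper, not part of openai_api.py.
import tiktoken


def count_usage(query: str, response: str) -> dict:
    # encoding_for_model("gpt-3.5-turbo") resolves to the cl100k_base encoding,
    # which only approximates ChatGLM2-6B's SentencePiece tokenizer.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    prompt_tokens = len(encoding.encode(query))
    completion_tokens = len(encoding.encode(response))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


if __name__ == "__main__":
    print(count_usage("Hello, who are you?", "I am ChatGLM2-6B, an AI assistant."))

Note that the streaming chunks built in predict still omit a usage argument, so with usage declared as a plain required dict those chunk responses would likely fail validation; declaring it as Optional (e.g. usage: Optional[dict] = None) would probably be needed for the streaming path.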
requirements.txt
@@ -8,3 +8,4 @@ sentencepiece
 accelerate
 sse-starlette
 streamlit>=1.24.0
+tiktoken
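With tiktoken added to requirements.txt, a non-streaming call to the API should now return the usage block. Below is a rough client-side check, assuming the server from openai_api.py is running locally on its default port 8000 and that the requests package is available; the model name in the request is simply echoed back by the server.

import requests

# Hypothetical smoke test: send one question and print the usage block
# that this commit adds to the non-streaming response.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "chatglm2-6b",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(resp.json()["usage"])  # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}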