add support of max_tokens, temperature, top_p to openai_api.py

pull/78/head
VoidIsVoid 2023-06-27 19:04:44 +08:00 committed by GitHub
parent 8673270a4a
commit 12d8200b1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 5 deletions

View File

@ -56,7 +56,7 @@ class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
max_tokens: Optional[int] = None
stream: Optional[bool] = False
@ -105,10 +105,20 @@ async def create_chat_completion(request: ChatCompletionRequest):
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
generate = predict(query, history, request.model)
generate = predict(query,
history,
request.model,
max_length=request.max_tokens if request.max_tokens else 2048,
top_p=request.top_p if request.top_p else 0.7,
temperature=request.temperature if request.temperature else 0.95)
return StreamingResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history)
response, _ = model.chat(tokenizer,
query,
history=history,
max_length=request.max_tokens if request.max_tokens else 2048,
top_p=request.top_p if request.top_p else 0.7,
temperature=request.temperature if request.temperature else 0.95)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
@ -118,7 +128,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(query: str, history: List[List[str]], model_id: str):
async def predict(query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
@ -131,7 +141,12 @@ async def predict(query: str, history: List[List[str]], model_id: str):
current_length = 0
for new_response, _ in model.stream_chat(tokenizer, query, history):
for new_response, _ in model.stream_chat(tokenizer,
query,
history,
max_length=max_length,
top_p=top_p,
temperature=temperature):
if len(new_response) == current_length:
continue
@ -153,6 +168,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "data: [DONE]\n\n"
if __name__ == "__main__":