From 12d8200b1df20301c0c7bd226c9132001e1e1001 Mon Sep 17 00:00:00 2001
From: VoidIsVoid <343750470@qq.com>
Date: Tue, 27 Jun 2023 19:04:44 +0800
Subject: [PATCH] add support of max_tokens, temperature, top_p to openai_api.py

---
 openai_api.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/openai_api.py b/openai_api.py
index cc82967..c0d54cc 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -56,7 +56,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    max_length: Optional[int] = None
+    max_tokens: Optional[int] = None
     stream: Optional[bool] = False
 
 
@@ -105,10 +105,20 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
     if request.stream:
-        generate = predict(query, history, request.model)
+        generate = predict(query,
+                           history,
+                           request.model,
+                           max_length=request.max_tokens if request.max_tokens else 2048,
+                           top_p=request.top_p if request.top_p else 0.7,
+                           temperature=request.temperature if request.temperature else 0.95)
         return StreamingResponse(generate, media_type="text/event-stream")
 
-    response, _ = model.chat(tokenizer, query, history=history)
+    response, _ = model.chat(tokenizer,
+                             query,
+                             history=history,
+                             max_length=request.max_tokens if request.max_tokens else 2048,
+                             top_p=request.top_p if request.top_p else 0.7,
+                             temperature=request.temperature if request.temperature else 0.95)
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
@@ -118,7 +128,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 
 
-async def predict(query: str, history: List[List[str]], model_id: str):
+async def predict(query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
     global model, tokenizer
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -131,7 +141,12 @@ async def predict(query: str, history: List[List[str]], model_id: str):
 
     current_length = 0
 
-    for new_response, _ in model.stream_chat(tokenizer, query, history):
+    for new_response, _ in model.stream_chat(tokenizer,
+                                             query,
+                                             history,
+                                             max_length=max_length,
+                                             top_p=top_p,
+                                             temperature=temperature):
         if len(new_response) == current_length:
             continue
 
@@ -153,6 +168,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "data: [DONE]\n\n"
 
 
 if __name__ == "__main__":
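
For reference, a minimal client call that exercises the new request fields could look like the sketch below. It assumes the server started by openai_api.py listens on http://localhost:8000 and exposes the chat completions route at /v1/chat/completions, and it uses an arbitrary model id; none of these appear in the hunks above, so adjust them to your deployment.

    # Sketch only: host, port, route, and model id are assumptions, not taken from the patch.
    import requests

    payload = {
        "model": "chatglm2-6b",                              # echoed back by the server; assumed id
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 512,                                   # forwarded as max_length (2048 when omitted)
        "temperature": 0.8,                                  # 0.95 when omitted
        "top_p": 0.8,                                        # 0.7 when omitted
        "stream": False,
    }

    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
    resp.raise_for_status()
    # Non-streaming responses carry the reply in choices[0].message.content,
    # matching the ChatCompletionResponseChoice built in the patch.
    print(resp.json()["choices"][0]["message"]["content"])

With "stream": true the server instead returns a text/event-stream, and the final data event is now "data: [DONE]", so OpenAI-style streaming clients can detect the end of the stream.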