From c3930e98323ff5cc19ba08a47c5a43fb78d88ec5 Mon Sep 17 00:00:00 2001
From: VoidIsVoid <343750470@qq.com>
Date: Wed, 28 Jun 2023 10:11:05 +0800
Subject: [PATCH] add support for max_tokens, temperature, top_p to openai_api.py

---
 openai_api.py | 38 +++++++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/openai_api.py b/openai_api.py
index c0d54cc..83a18b6 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -56,6 +56,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    max_length: Optional[int] = None
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
 
@@ -104,21 +105,28 @@ async def create_chat_completion(request: ChatCompletionRequest):
             if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
+    chat_kwargs = {}
+    if request.max_tokens is not None:
+        chat_kwargs['max_new_tokens'] = request.max_tokens
+    elif request.max_length:
+        chat_kwargs['max_length'] = request.max_length
+
     if request.stream:
-        generate = predict(query, 
-                           history, 
-                           request.model, 
-                           max_length=request.max_tokens if request.max_tokens else 2048,
+        generate = predict(query,
+                           history,
+                           request.model,
                            top_p=request.top_p if request.top_p else 0.7,
-                           temperature=request.temperature if request.temperature else 0.95)
+                           temperature=request.temperature if request.temperature else 0.95,
+                           **chat_kwargs)
         return StreamingResponse(generate, media_type="text/event-stream")
 
-    response, _ = model.chat(tokenizer, 
-                             query, 
-                             history=history, 
-                             max_length=request.max_tokens if request.max_tokens else 2048,
+    response, _ = model.chat(tokenizer,
+                             query,
+                             history=history,
                              top_p=request.top_p if request.top_p else 0.7,
-                             temperature=request.temperature if request.temperature else 0.95)
+                             temperature=request.temperature if request.temperature else 0.95,
+                             **chat_kwargs)
+
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
@@ -128,7 +136,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 
 
-async def predict(query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
+async def predict(query: str, history: List[List[str]], model_id: str, top_p: float, temperature: float, **kwargs):
     global model, tokenizer
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -141,12 +149,12 @@ async def predict(query: str, history: List[List[str]], model_id: str, max_lengt
 
     current_length = 0
 
-    for new_response, _ in model.stream_chat(tokenizer, 
-                                             query, 
+    for new_response, _ in model.stream_chat(tokenizer,
+                                             query,
                                              history,
-                                             max_length=max_length,
                                              top_p=top_p,
-                                             temperature=temperature):
+                                             temperature=temperature,
+                                             **kwargs):
         if len(new_response) == current_length:
             continue
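
Usage note (editor's sketch, not part of the patch): with the patched openai_api.py running, the new request fields can be exercised with a plain HTTP call. Per the patch, max_tokens is forwarded to the model as max_new_tokens, and max_length is only used when max_tokens is absent. The host/port, endpoint path, and model name below are assumptions for illustration.

    # Editor's sketch (assumed server at localhost:8000, assumed model id):
    # sends max_tokens / temperature / top_p to the modified chat endpoint.
    import requests

    payload = {
        "model": "chatglm2-6b",                                   # assumed model id
        "messages": [{"role": "user", "content": "Hello there"}],
        "max_tokens": 128,      # forwarded as max_new_tokens by this patch
        "temperature": 0.8,
        "top_p": 0.7,
        "stream": False,
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions",
                         json=payload, timeout=60)
    resp.raise_for_status()
    # ChatCompletionResponse carries the reply in choices[0].message.content
    print(resp.json()["choices"][0]["message"]["content"])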