add support of max_tokens, temperature, top_p to openai_api.py

pull/78/head
VoidIsVoid 2023-06-28 10:11:05 +08:00 committed by GitHub
parent 12d8200b1d
commit c3930e9832
1 changed file with 23 additions and 15 deletions

openai_api.py

@@ -56,6 +56,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    max_length: Optional[int] = None
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
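Aside from the diff: with both limit fields optional and defaulting to None, a request that sets only max_tokens (or neither field) still validates. Below is a minimal, self-contained sketch of the request schema from the hunk above; the model id is only an example and the class bodies are trimmed to the fields shown in the diff.

# Illustration only: a trimmed stand-in for the request models in openai_api.py.
from typing import List, Optional
from pydantic import BaseModel

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False

req = ChatCompletionRequest(model="chatglm2-6b",                      # example model id
                            messages=[{"role": "user", "content": "Hi"}],
                            max_tokens=128)
print(req.max_tokens, req.max_length, req.temperature)                # 128 None None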
@@ -104,21 +105,28 @@ async def create_chat_completion(request: ChatCompletionRequest):
             if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
+    chat_kwargs = {}
+    if request.max_tokens is not None:
+        chat_kwargs['max_new_tokens'] = request.max_tokens
+    elif request.max_length:
+        chat_kwargs['max_length'] = request.max_length
     if request.stream:
-        generate = predict(query,
-                           history,
-                           request.model,
-                           max_length=request.max_tokens if request.max_tokens else 2048,
+        generate = predict(query,
+                           history,
+                           request.model,
                            top_p=request.top_p if request.top_p else 0.7,
-                           temperature=request.temperature if request.temperature else 0.95)
+                           temperature=request.temperature if request.temperature else 0.95,
+                           **chat_kwargs)
         return StreamingResponse(generate, media_type="text/event-stream")
-    response, _ = model.chat(tokenizer,
-                             query,
-                             history=history,
-                             max_length=request.max_tokens if request.max_tokens else 2048,
+    response, _ = model.chat(tokenizer,
+                             query,
+                             history=history,
                              top_p=request.top_p if request.top_p else 0.7,
-                             temperature=request.temperature if request.temperature else 0.95)
+                             temperature=request.temperature if request.temperature else 0.95,
+                             **chat_kwargs)
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
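Aside from the diff: the new chat_kwargs block encodes a small priority rule. When max_tokens is present it is forwarded to the model as max_new_tokens; otherwise a non-zero max_length is passed through unchanged; if neither is set, nothing is passed and the model's own defaults apply instead of the previous hard-coded 2048. A standalone sketch of that rule follows (the helper name is hypothetical).

# Standalone sketch, not part of the diff: the priority rule behind chat_kwargs.
# build_chat_kwargs is a hypothetical name used only for illustration.
from typing import Optional

def build_chat_kwargs(max_tokens: Optional[int], max_length: Optional[int]) -> dict:
    kwargs = {}
    if max_tokens is not None:
        kwargs['max_new_tokens'] = max_tokens   # OpenAI-style max_tokens: newly generated tokens only
    elif max_length:
        kwargs['max_length'] = max_length       # legacy total-length cap, passed through as-is
    return kwargs                               # empty dict -> chat()/stream_chat() defaults apply

assert build_chat_kwargs(256, 4096) == {'max_new_tokens': 256}   # max_tokens takes priority
assert build_chat_kwargs(None, 4096) == {'max_length': 4096}
assert build_chat_kwargs(None, None) == {}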
@@ -128,7 +136,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
-async def predict(query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
+async def predict(query: str, history: List[List[str]], model_id: str, top_p: float, temperature: float, **kwargs):
     global model, tokenizer
     choice_data = ChatCompletionResponseStreamChoice(
@@ -141,12 +149,12 @@ async def predict(query: str, history: List[List[str]], model_id: str, max_lengt
     current_length = 0
-    for new_response, _ in model.stream_chat(tokenizer,
-                                             query,
+    for new_response, _ in model.stream_chat(tokenizer,
+                                             query,
                                              history,
-                                             max_length=max_length,
                                              top_p=top_p,
-                                             temperature=temperature):
+                                             temperature=temperature,
+                                             **kwargs):
         if len(new_response) == current_length:
             continue
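Finally, a client-side sketch of how the new fields could be exercised once the server is running. Only the field names and the response shape come from this commit; the address, route, and model id below are assumptions for illustration.

# Hypothetical client call against the OpenAI-compatible endpoint served by
# openai_api.py; the URL and model id are assumed, not taken from this commit.
import requests

payload = {
    "model": "chatglm2-6b",                      # assumed model id
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "temperature": 0.8,
    "top_p": 0.9,
    "max_tokens": 128,                           # forwarded as max_new_tokens
    # "max_length": 4096,                        # only used when max_tokens is absent
}

resp = requests.post("http://127.0.0.1:8000/v1/chat/completions",    # assumed address/route
                     json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])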