mirror of https://github.com/THUDM/ChatGLM2-6B
add support for max_tokens, temperature and top_p to openai_api.py
parent 8673270a4a
commit 12d8200b1d
@@ -56,7 +56,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    max_length: Optional[int] = None
+    max_tokens: Optional[int] = None
     stream: Optional[bool] = False
@@ -105,10 +105,20 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
     if request.stream:
-        generate = predict(query, history, request.model)
+        generate = predict(query,
+                           history,
+                           request.model,
+                           max_length=request.max_tokens if request.max_tokens else 2048,
+                           top_p=request.top_p if request.top_p else 0.7,
+                           temperature=request.temperature if request.temperature else 0.95)
         return StreamingResponse(generate, media_type="text/event-stream")
 
-    response, _ = model.chat(tokenizer, query, history=history)
+    response, _ = model.chat(tokenizer,
+                             query,
+                             history=history,
+                             max_length=request.max_tokens if request.max_tokens else 2048,
+                             top_p=request.top_p if request.top_p else 0.7,
+                             temperature=request.temperature if request.temperature else 0.95)
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
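A side note on the default handling introduced above: `request.max_tokens if request.max_tokens else 2048` (and likewise for top_p and temperature) falls back to the default whenever the field is falsy, so an explicit `temperature=0.0` or `max_tokens=0` from a client is silently replaced. A minimal sketch of a stricter alternative that only falls back on None; the value_or_default helper is hypothetical and not part of this commit:

# Hypothetical helper, not in this commit: fall back only when the field is absent,
# so an explicit 0 / 0.0 sent by the client is honored instead of being replaced.
def value_or_default(value, default):
    return default if value is None else value

# The commit's pattern treats 0.0 as "missing"; the None check keeps it.
assert (0.0 if 0.0 else 0.95) == 0.95        # behavior of the pattern in the diff
assert value_or_default(0.0, 0.95) == 0.0    # None-checking alternative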
@@ -118,7 +128,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 
 
-async def predict(query: str, history: List[List[str]], model_id: str):
+async def predict(query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
     global model, tokenizer
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -131,7 +141,12 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     current_length = 0
 
-    for new_response, _ in model.stream_chat(tokenizer, query, history):
+    for new_response, _ in model.stream_chat(tokenizer,
+                                             query,
+                                             history,
+                                             max_length=max_length,
+                                             top_p=top_p,
+                                             temperature=temperature):
         if len(new_response) == current_length:
             continue
@@ -153,6 +168,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "data: [DONE]\n\n"
 
 
 if __name__ == "__main__":
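For reference, a minimal client call exercising the new fields. This is a sketch, assuming the server is started with the script's defaults (uvicorn on port 8000, route /v1/chat/completions) and using the requests library rather than the openai SDK; adjust the URL and model name to your deployment:

import requests

url = "http://localhost:8000/v1/chat/completions"  # assumed default host/port

payload = {
    "model": "chatglm2-6b",
    "messages": [{"role": "user", "content": "你好"}],
    "max_tokens": 512,      # forwarded to the model as max_length
    "temperature": 0.8,
    "top_p": 0.9,
    "stream": False,
}

resp = requests.post(url, json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])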
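And a rough sketch of consuming the streaming path, stopping on the "data: [DONE]" sentinel yielded at the end of predict() (same endpoint assumptions as above):

import json
import requests

url = "http://localhost:8000/v1/chat/completions"  # assumed default host/port
payload = {
    "model": "chatglm2-6b",
    "messages": [{"role": "user", "content": "你好"}],
    "max_tokens": 512,
    "temperature": 0.8,
    "top_p": 0.9,
    "stream": True,
}

with requests.post(url, json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        data = raw[len("data: "):]
        if data == "[DONE]":   # end-of-stream sentinel added in the last hunk
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)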