mirror of https://github.com/THUDM/ChatGLM2-6B
add support for max_tokens, temperature, top_p in openai_api.py
parent 12d8200b1d · commit c3930e9832
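
The new ChatCompletionRequest fields can be exercised with a plain HTTP call against the server that openai_api.py starts. The snippet below is a minimal sketch, not part of the commit: it assumes the server is running locally on port 8000 (this repository's default) and that the model is exposed as "chatglm2-6b"; adjust both if your deployment differs.

# Minimal sketch: exercise the new request fields over HTTP.
# Assumes openai_api.py is serving on localhost:8000 and the model
# name "chatglm2-6b"; both are deployment assumptions, not part of this diff.
import requests

payload = {
    "model": "chatglm2-6b",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.8,   # forwarded to model.chat / model.stream_chat
    "top_p": 0.9,         # forwarded as well
    "max_tokens": 256,    # mapped to max_new_tokens by this change
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
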
openai_api.py

@@ -56,6 +56,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     max_length: Optional[int] = None
+    max_tokens: Optional[int] = None
     stream: Optional[bool] = False
 
@@ -104,21 +105,28 @@ async def create_chat_completion(request: ChatCompletionRequest):
             if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
+    chat_kwargs = {}
+    if request.max_tokens is not None:
+        chat_kwargs['max_new_tokens'] = request.max_tokens
+    elif request.max_length:
+        chat_kwargs['max_length'] = request.max_length
+
     if request.stream:
         generate = predict(query,
                            history,
                            request.model,
-                           max_length=request.max_tokens if request.max_tokens else 2048,
                            top_p=request.top_p if request.top_p else 0.7,
-                           temperature=request.temperature if request.temperature else 0.95)
+                           temperature=request.temperature if request.temperature else 0.95,
+                           **chat_kwargs)
         return StreamingResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer,
                              query,
                              history=history,
-                             max_length=request.max_tokens if request.max_tokens else 2048,
                              top_p=request.top_p if request.top_p else 0.7,
-                             temperature=request.temperature if request.temperature else 0.95)
+                             temperature=request.temperature if request.temperature else 0.95,
+                             **chat_kwargs)
+
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
@@ -128,7 +136,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 
 
-async def predict(query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
+async def predict(query: str, history: List[List[str]], model_id: str, top_p: float, temperature: float, **kwargs):
     global model, tokenizer
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -141,12 +149,12 @@ async def predict(query: str, history: List[List[str]], model_id: str, max_lengt
 
     current_length = 0
 
     for new_response, _ in model.stream_chat(tokenizer,
                                              query,
                                              history,
-                                             max_length=max_length,
                                              top_p=top_p,
-                                             temperature=temperature):
+                                             temperature=temperature,
+                                             **kwargs):
         if len(new_response) == current_length:
             continue
 
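
Because predict() now forwards **kwargs to model.stream_chat, the same fields also apply to streaming responses. Below is a rough usage sketch with the pre-1.0 openai Python SDK, again assuming a local server on port 8000 and the model name "chatglm2-6b" (deployment assumptions, not part of this commit).

# Rough sketch with the legacy (pre-1.0) openai SDK; API base, key, and model
# name describe an assumed local deployment of openai_api.py.
import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"  # this server performs no key check

for chunk in openai.ChatCompletion.create(
        model="chatglm2-6b",
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.8,
        top_p=0.9,
        max_tokens=256,   # reaches model.stream_chat via **kwargs as max_new_tokens
        stream=True):
    delta = chunk.choices[0].delta
    if hasattr(delta, "content"):
        print(delta.content, end="", flush=True)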