update apis/inference.py

pull/478/head
YWMditto 2023-11-09 16:17:32 +08:00
parent 0fb8dbab3a
commit 47c82aa223
1 changed files with 20 additions and 34 deletions

View File

@ -162,7 +162,10 @@ class SequenceGenerator:
repetition_penalty: float = 1,
length_penalty: float = 1.0,
):
if do_sample:
if not do_sample:
temperature = (1,)
top_k = (50,)
top_p = (1,)
yield from _streaming_no_beam_search_generate(
self.decoder,
tokens=tokens,
@ -173,24 +176,7 @@ class SequenceGenerator:
eos_token_id=self.eos_token_id,
additional_eos_token_list=self.additional_eos_token_list,
add_eos_when_return=self.add_eos_when_return,
do_sample=True,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
)
else:
yield from _streaming_no_beam_search_generate(
self.decoder,
tokens=tokens,
max_length=max_length,
temperature=1,
top_k=50,
top_p=1,
eos_token_id=self.eos_token_id,
additional_eos_token_list=self.additional_eos_token_list,
add_eos_when_return=self.add_eos_when_return,
do_sample=False,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=self.pad_token_id,