update apis/inference.py

pull/478/head
YWMditto 2023-11-09 16:17:32 +08:00
parent 0fb8dbab3a
commit 47c82aa223
1 changed file with 20 additions and 34 deletions

View File

@@ -162,40 +162,26 @@ class SequenceGenerator:
repetition_penalty: float = 1, repetition_penalty: float = 1,
length_penalty: float = 1.0, length_penalty: float = 1.0,
): ):
if do_sample: if not do_sample:
yield from _streaming_no_beam_search_generate( temperature = (1,)
self.decoder, top_k = (50,)
tokens=tokens, top_p = (1,)
max_length=max_length, yield from _streaming_no_beam_search_generate(
temperature=temperature, self.decoder,
top_k=top_k, tokens=tokens,
top_p=top_p, max_length=max_length,
eos_token_id=self.eos_token_id, temperature=temperature,
additional_eos_token_list=self.additional_eos_token_list, top_k=top_k,
add_eos_when_return=self.add_eos_when_return, top_p=top_p,
do_sample=True, eos_token_id=self.eos_token_id,
repetition_penalty=repetition_penalty, additional_eos_token_list=self.additional_eos_token_list,
length_penalty=length_penalty, add_eos_when_return=self.add_eos_when_return,
pad_token_id=self.pad_token_id, do_sample=do_sample,
bos_token_id=self.bos_token_id, repetition_penalty=repetition_penalty,
) length_penalty=length_penalty,
else: pad_token_id=self.pad_token_id,
yield from _streaming_no_beam_search_generate( bos_token_id=self.bos_token_id,
self.decoder, )
tokens=tokens,
max_length=max_length,
temperature=1,
top_k=50,
top_p=1,
eos_token_id=self.eos_token_id,
additional_eos_token_list=self.additional_eos_token_list,
add_eos_when_return=self.add_eos_when_return,
do_sample=False,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
)
@torch.no_grad() @torch.no_grad()