mirror of https://github.com/InternLM/InternLM
update apis/inference.py
parent
0fb8dbab3a
commit
47c82aa223
|
@ -162,40 +162,26 @@ class SequenceGenerator:
|
|||
repetition_penalty: float = 1,
|
||||
length_penalty: float = 1.0,
|
||||
):
|
||||
if do_sample:
|
||||
yield from _streaming_no_beam_search_generate(
|
||||
self.decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
eos_token_id=self.eos_token_id,
|
||||
additional_eos_token_list=self.additional_eos_token_list,
|
||||
add_eos_when_return=self.add_eos_when_return,
|
||||
do_sample=True,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
else:
|
||||
yield from _streaming_no_beam_search_generate(
|
||||
self.decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
temperature=1,
|
||||
top_k=50,
|
||||
top_p=1,
|
||||
eos_token_id=self.eos_token_id,
|
||||
additional_eos_token_list=self.additional_eos_token_list,
|
||||
add_eos_when_return=self.add_eos_when_return,
|
||||
do_sample=False,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
if not do_sample:
|
||||
temperature = (1,)
|
||||
top_k = (50,)
|
||||
top_p = (1,)
|
||||
yield from _streaming_no_beam_search_generate(
|
||||
self.decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
eos_token_id=self.eos_token_id,
|
||||
additional_eos_token_list=self.additional_eos_token_list,
|
||||
add_eos_when_return=self.add_eos_when_return,
|
||||
do_sample=do_sample,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
Loading…
Reference in New Issue