mirror of https://github.com/InternLM/InternLM
update apis/inference.py
parent
0fb8dbab3a
commit
47c82aa223
|
@ -162,40 +162,26 @@ class SequenceGenerator:
|
||||||
repetition_penalty: float = 1,
|
repetition_penalty: float = 1,
|
||||||
length_penalty: float = 1.0,
|
length_penalty: float = 1.0,
|
||||||
):
|
):
|
||||||
if do_sample:
|
if not do_sample:
|
||||||
yield from _streaming_no_beam_search_generate(
|
temperature = (1,)
|
||||||
self.decoder,
|
top_k = (50,)
|
||||||
tokens=tokens,
|
top_p = (1,)
|
||||||
max_length=max_length,
|
yield from _streaming_no_beam_search_generate(
|
||||||
temperature=temperature,
|
self.decoder,
|
||||||
top_k=top_k,
|
tokens=tokens,
|
||||||
top_p=top_p,
|
max_length=max_length,
|
||||||
eos_token_id=self.eos_token_id,
|
temperature=temperature,
|
||||||
additional_eos_token_list=self.additional_eos_token_list,
|
top_k=top_k,
|
||||||
add_eos_when_return=self.add_eos_when_return,
|
top_p=top_p,
|
||||||
do_sample=True,
|
eos_token_id=self.eos_token_id,
|
||||||
repetition_penalty=repetition_penalty,
|
additional_eos_token_list=self.additional_eos_token_list,
|
||||||
length_penalty=length_penalty,
|
add_eos_when_return=self.add_eos_when_return,
|
||||||
pad_token_id=self.pad_token_id,
|
do_sample=do_sample,
|
||||||
bos_token_id=self.bos_token_id,
|
repetition_penalty=repetition_penalty,
|
||||||
)
|
length_penalty=length_penalty,
|
||||||
else:
|
pad_token_id=self.pad_token_id,
|
||||||
yield from _streaming_no_beam_search_generate(
|
bos_token_id=self.bos_token_id,
|
||||||
self.decoder,
|
)
|
||||||
tokens=tokens,
|
|
||||||
max_length=max_length,
|
|
||||||
temperature=1,
|
|
||||||
top_k=50,
|
|
||||||
top_p=1,
|
|
||||||
eos_token_id=self.eos_token_id,
|
|
||||||
additional_eos_token_list=self.additional_eos_token_list,
|
|
||||||
add_eos_when_return=self.add_eos_when_return,
|
|
||||||
do_sample=False,
|
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
length_penalty=length_penalty,
|
|
||||||
pad_token_id=self.pad_token_id,
|
|
||||||
bos_token_id=self.bos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
Loading…
Reference in New Issue