diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py
index 9f60115..d07a25f 100644
--- a/internlm/apis/inference.py
+++ b/internlm/apis/inference.py
@@ -162,40 +162,26 @@ class SequenceGenerator:
         repetition_penalty: float = 1,
         length_penalty: float = 1.0,
     ):
-        if do_sample:
-            yield from _streaming_no_beam_search_generate(
-                self.decoder,
-                tokens=tokens,
-                max_length=max_length,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                eos_token_id=self.eos_token_id,
-                additional_eos_token_list=self.additional_eos_token_list,
-                add_eos_when_return=self.add_eos_when_return,
-                do_sample=True,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
-                pad_token_id=self.pad_token_id,
-                bos_token_id=self.bos_token_id,
-            )
-        else:
-            yield from _streaming_no_beam_search_generate(
-                self.decoder,
-                tokens=tokens,
-                max_length=max_length,
-                temperature=1,
-                top_k=50,
-                top_p=1,
-                eos_token_id=self.eos_token_id,
-                additional_eos_token_list=self.additional_eos_token_list,
-                add_eos_when_return=self.add_eos_when_return,
-                do_sample=False,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
-                pad_token_id=self.pad_token_id,
-                bos_token_id=self.bos_token_id,
-            )
+        if not do_sample:
+            temperature = 1
+            top_k = 50
+            top_p = 1
+        yield from _streaming_no_beam_search_generate(
+            self.decoder,
+            tokens=tokens,
+            max_length=max_length,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            eos_token_id=self.eos_token_id,
+            additional_eos_token_list=self.additional_eos_token_list,
+            add_eos_when_return=self.add_eos_when_return,
+            do_sample=do_sample,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            pad_token_id=self.pad_token_id,
+            bos_token_id=self.bos_token_id,
+        )
 
     @torch.no_grad()
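
For reference, below is a minimal, self-contained sketch of the control flow this hunk introduces: when sampling is disabled, the sampling knobs are pinned to deterministic defaults so a single call site serves both branches. Only the function name _streaming_no_beam_search_generate, the keyword names, and the default values (temperature=1, top_k=50, top_p=1) come from the diff; the stub decoder, the toy token values, and the reduced parameter list are illustrative assumptions, not InternLM's actual API.

from typing import Callable, Iterator, List


def _streaming_no_beam_search_generate(
    decoder: Callable,
    tokens: List[int],
    max_length: int,
    temperature: float,
    top_k: int,
    top_p: float,
    do_sample: bool,
) -> Iterator[List[int]]:
    # Stub standing in for the real generator: appends one token per
    # step and yields the partial sequence, mimicking streaming output.
    while len(tokens) < max_length:
        tokens = tokens + [decoder(tokens, temperature, top_k, top_p, do_sample)]
        yield tokens


def streaming_generate(decoder, tokens, max_length, do_sample, temperature, top_k, top_p):
    if not do_sample:
        # Greedy decoding ignores the sampling knobs, so pin them to the
        # defaults the old `else` branch hard-coded; this lets both
        # branches share one call to the underlying generator.
        temperature, top_k, top_p = 1, 50, 1
    yield from _streaming_no_beam_search_generate(
        decoder,
        tokens=tokens,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )


if __name__ == "__main__":
    # Toy "argmax" decoder: always picks the next integer, i.e. greedy.
    greedy_decoder = lambda toks, t, k, p, s: max(toks) + 1
    for partial in streaming_generate(greedy_decoder, [1, 2], 6, False, 0.7, 40, 0.9):
        print(partial)

Note that the sampling-knob assignments must be plain scalars; a trailing comma (e.g. temperature = 1,) would silently turn them into one-element tuples and break the downstream generator.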