From 47c82aa223927a76270388ef0e651d395cdaf115 Mon Sep 17 00:00:00 2001
From: YWMditto <862779238@qq.com>
Date: Thu, 9 Nov 2023 16:17:32 +0800
Subject: [PATCH] update apis/inference.py

---
 internlm/apis/inference.py | 54 ++++++++++++++------------------------
 1 file changed, 20 insertions(+), 34 deletions(-)

diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py
index 9f60115..d07a25f 100644
--- a/internlm/apis/inference.py
+++ b/internlm/apis/inference.py
@@ -162,40 +162,26 @@ class SequenceGenerator:
         repetition_penalty: float = 1,
         length_penalty: float = 1.0,
     ):
-        if do_sample:
-            yield from _streaming_no_beam_search_generate(
-                self.decoder,
-                tokens=tokens,
-                max_length=max_length,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                eos_token_id=self.eos_token_id,
-                additional_eos_token_list=self.additional_eos_token_list,
-                add_eos_when_return=self.add_eos_when_return,
-                do_sample=True,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
-                pad_token_id=self.pad_token_id,
-                bos_token_id=self.bos_token_id,
-            )
-        else:
-            yield from _streaming_no_beam_search_generate(
-                self.decoder,
-                tokens=tokens,
-                max_length=max_length,
-                temperature=1,
-                top_k=50,
-                top_p=1,
-                eos_token_id=self.eos_token_id,
-                additional_eos_token_list=self.additional_eos_token_list,
-                add_eos_when_return=self.add_eos_when_return,
-                do_sample=False,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
-                pad_token_id=self.pad_token_id,
-                bos_token_id=self.bos_token_id,
-            )
+        if not do_sample:
+            temperature = 1
+            top_k = 50
+            top_p = 1
+        yield from _streaming_no_beam_search_generate(
+            self.decoder,
+            tokens=tokens,
+            max_length=max_length,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            eos_token_id=self.eos_token_id,
+            additional_eos_token_list=self.additional_eos_token_list,
+            add_eos_when_return=self.add_eos_when_return,
+            do_sample=do_sample,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            pad_token_id=self.pad_token_id,
+            bos_token_id=self.bos_token_id,
+        )

     @torch.no_grad()
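
Note on the refactor above: when do_sample is False, the method now rewrites
temperature/top_k/top_p to the fixed values (1, 50, 1) that the deleted
else-branch used to pass, and forwards do_sample through unchanged, so the two
previously duplicated calls collapse into one. Below is a minimal usage sketch
of the streaming API under that behavior; the method name streaming_generate,
the constructor arguments, and the token ids are assumptions for illustration,
not taken from this patch.

import torch

# Hypothetical setup: a SequenceGenerator wrapping some causal decoder.
# The special-token ids here are placeholders, not real vocabulary entries.
generator = SequenceGenerator(
    decoder=model,    # assumption: an InternLM-style decoder module
    eos_token_id=2,   # assumption: model-specific special-token ids
    pad_token_id=0,
    bos_token_id=1,
)

prompt = torch.tensor([[1, 100, 101]])  # (batch, seq_len) prompt token ids

# With do_sample=False the generator neutralizes temperature/top_k/top_p
# itself, so greedy decoding needs no sampling arguments from the caller.
for partial_tokens in generator.streaming_generate(
    tokens=prompt,
    max_length=64,
    do_sample=False,
):
    print(partial_tokens)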