From 94fdd178ba2e864cca5d07ab6d7f1ce9afcadebe Mon Sep 17 00:00:00 2001
From: YWMditto <862779238@qq.com>
Date: Mon, 6 Nov 2023 19:54:45 +0800
Subject: [PATCH] debug for web_demo_internlm

---
 internlm/apis/inference.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py
index 88d6d50..a7a361e 100644
--- a/internlm/apis/inference.py
+++ b/internlm/apis/inference.py
@@ -85,6 +85,7 @@ class SequenceGenerator:
         top_p: float = 1.0,
         repetition_penalty: float = 1,
         length_penalty: float = 1.0,
+        streaming=False
     ):
         """
         Args:
@@ -119,6 +120,7 @@ class SequenceGenerator:
                 length_penalty=length_penalty,  # the penalty for length. if it > 1, then encourages long sequence.
                 # Otherwise, encourages short sequence.
                 bos_token_id=self.bos_token_id,
+                streaming=streaming
             )
         else:
             return greedy_generate(
@@ -132,6 +134,7 @@ class SequenceGenerator:
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
                 bos_token_id=self.bos_token_id,
+                streaming=streaming
             )


@@ -147,6 +150,7 @@ def greedy_generate(
     repetition_penalty=1,
     length_penalty=1.0,
     bos_token_id=1,
+    streaming=False,
     feat_mask=None,
     ffn_mask=None,
     layer_mask=None,
@@ -179,6 +183,7 @@ def greedy_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            streaming=streaming,
             feat_mask=feat_mask,
             ffn_mask=ffn_mask,
             layer_mask=layer_mask,
@@ -199,6 +204,7 @@ def greedy_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            # streaming=streaming,
             feat_mask=feat_mask,
             ffn_mask=ffn_mask,
             layer_mask=layer_mask,
@@ -222,6 +228,7 @@ def sample_generate(
     repetition_penalty=1.0,
     length_penalty=1.0,
     bos_token_id=1,
+    streaming=False,
 ):
     """
     generate sequence in sampling way.
@@ -255,6 +262,7 @@ def sample_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            streaming=streaming,
         )
     else:
         token_ids = _beam_search_generate(
@@ -272,6 +280,7 @@ def sample_generate(
             repetition_penalty=repetition_penalty,
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            # streaming=streaming,
         )
     return token_ids
@@ -291,11 +300,11 @@ def _no_beam_search_generate(
     length_penalty=1.0,
     pad_token_id=0,
     bos_token_id=1,
+    streaming=False,
     feat_mask=None,
     ffn_mask=None,
     layer_mask=None,
 ):
-    # delete num_return_sequences=1 for lint check;
     batch_size = tokens.size(0)
     if eos_token_id is None:
         _eos_token_id = -1
@@ -356,8 +365,14 @@ def _no_beam_search_generate(
         inference_params.sequence_len_offset += tokens.size(1)
     if _eos_token_id != -1:
         scores[:, _eos_token_id] = -1e12
+    # The first token generated.
     next_tokens = scores.argmax(dim=-1, keepdim=True)
+    print(f"Input tokens: {tokens}")
+    _generated_idx = 0
+    print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
     token_ids = torch.cat([tokens, next_tokens], dim=1)
+    if streaming:
+        yield token_ids
     cur_len = token_ids.size(1)
     dones = token_ids.new_zeros(batch_size).eq(1)
     # tokens = tokens[:, -1:]
@@ -417,7 +432,8 @@ def _no_beam_search_generate(
                 lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
             )
             scores.scatter_(dim=1, index=token_ids, src=token_scores)
-
+            # scores: [bsz, vocab_size]
+            import pdb; pdb.set_trace()
         if eos_token_id is not None and length_penalty != 1.0:
             # batch_size x vocab_size
             token_scores = scores / cur_len**length_penalty
@@ -444,7 +460,11 @@ def _no_beam_search_generate(
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)
         tokens = next_tokens.unsqueeze(1)
+        print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
+
         token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
+        if streaming:
+            yield token_ids
         end_mask = next_tokens.eq(_eos_token_id)
         dones = dones.__or__(end_mask)

@@ -461,7 +481,8 @@ def _no_beam_search_generate(
     # token_ids[:, -1].masked_fill_(~dones, eos_token_id)
     # TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
     # be able to return multiple real results
-    return token_ids[:, None]
+    if not streaming:
+        return token_ids[:, None]


 @torch.no_grad()
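
Usage note (not part of the patch): below is a minimal caller-side sketch of how a web demo might consume the streaming output enabled above. The tokenizer, the generator construction, and every parameter name other than streaming are assumptions made for illustration, not taken from this diff. Also note that because _no_beam_search_generate now contains yield, it is a generator function in both modes; with streaming=False its final return value is delivered via StopIteration rather than as an ordinary return value.

    # Hypothetical caller-side sketch; not part of this diff.
    # sequence_generator is an already-constructed SequenceGenerator and
    # tokenizer is any object with a decode() method -- both are assumed here.
    def stream_decode(sequence_generator, tokenizer, input_tokens):
        """Yield the decoded text after every newly generated token."""
        for token_ids in sequence_generator.generate(
            input_tokens,      # prompt tokens, shape [batch_size, seq_len] (assumed argument name)
            do_sample=True,    # assumed flag selecting the sampling path
            streaming=True,    # flag introduced by this patch
        ):
            # token_ids has shape [batch_size, cur_len]; decode only the first sample here.
            yield tokenizer.decode(token_ids[0].tolist())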