debug for web_demo_internlm

pull/478/head
YWMditto 2023-11-06 19:54:45 +08:00
parent b9c813a972
commit 94fdd178ba
1 changed file with 24 additions and 3 deletions


@@ -85,6 +85,7 @@ class SequenceGenerator:
 top_p: float = 1.0,
 repetition_penalty: float = 1,
 length_penalty: float = 1.0,
+streaming=False
 ):
 """
 Args:
@@ -119,6 +120,7 @@
 length_penalty=length_penalty,  # the penalty for length. if it > 1, then encourages long sequence.
 # Otherwise, encourages short sequence.
 bos_token_id=self.bos_token_id,
+streaming=streaming
 )
 else:
 return greedy_generate(
@@ -132,6 +134,7 @@
 repetition_penalty=repetition_penalty,
 length_penalty=length_penalty,
 bos_token_id=self.bos_token_id,
+streaming=streaming
 )
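
With `streaming` threaded through `generate()`, the web demo can consume partial outputs step by step instead of waiting for the full reply. A minimal sketch of such a driver loop follows; `generator`, `tokenizer`, and the keyword names other than `streaming` (e.g. `tokens`, `do_sample`) are assumptions for illustration, not taken from this diff:

    # Hypothetical streaming consumer for web_demo_internlm (sketch only).
    # `generator` is assumed to be an already constructed SequenceGenerator and
    # `tokenizer` any object with a decode(list[int]) -> str method; the keyword
    # names tokens/do_sample are also assumptions.
    import torch

    def stream_reply(generator, tokenizer, prompt_ids: torch.Tensor):
        for token_ids in generator.generate(
            tokens=prompt_ids,   # assumed keyword: batch_size x prompt_len prompt ids
            do_sample=False,     # assumed: routes to the greedy_generate() path
            streaming=True,      # flag added in this commit
        ):
            # Each yielded tensor is batch_size x current_len and gains one
            # column per decoding step, so the demo can redraw the reply.
            yield tokenizer.decode(token_ids[0].tolist())
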
@@ -147,6 +150,7 @@ def greedy_generate(
 repetition_penalty=1,
 length_penalty=1.0,
 bos_token_id=1,
+streaming=False,
 feat_mask=None,
 ffn_mask=None,
 layer_mask=None,
@@ -179,6 +183,7 @@ def greedy_generate(
 length_penalty=length_penalty,
 pad_token_id=pad_token_id,
 bos_token_id=bos_token_id,
+streaming=streaming,
 feat_mask=feat_mask,
 ffn_mask=ffn_mask,
 layer_mask=layer_mask,
@@ -199,6 +204,7 @@ def greedy_generate(
 length_penalty=length_penalty,
 pad_token_id=pad_token_id,
 bos_token_id=bos_token_id,
+# streaming=streaming,
 feat_mask=feat_mask,
 ffn_mask=ffn_mask,
 layer_mask=layer_mask,
@@ -222,6 +228,7 @@ def sample_generate(
 repetition_penalty=1.0,
 length_penalty=1.0,
 bos_token_id=1,
+streaming=False,
 ):
 """
 generate sequence in sampling way.
@@ -255,6 +262,7 @@ def sample_generate(
 length_penalty=length_penalty,
 pad_token_id=pad_token_id,
 bos_token_id=bos_token_id,
+streaming=streaming,
 )
 else:
 token_ids = _beam_search_generate(
@@ -272,6 +280,7 @@ def sample_generate(
 length_penalty=length_penalty,
 pad_token_id=pad_token_id,
 bos_token_id=bos_token_id,
+# streaming=streaming,
 )
 return token_ids
@@ -291,11 +300,11 @@ def _no_beam_search_generate(
 length_penalty=1.0,
 pad_token_id=0,
 bos_token_id=1,
+streaming=False,
 feat_mask=None,
 ffn_mask=None,
 layer_mask=None,
 ):
-# delete num_return_sequences=1 for lint check;
 batch_size = tokens.size(0)
 if eos_token_id is None:
 _eos_token_id = -1
@@ -356,8 +365,14 @@ def _no_beam_search_generate(
 inference_params.sequence_len_offset += tokens.size(1)
 if _eos_token_id != -1:
 scores[:, _eos_token_id] = -1e12
+# The first token generated.
 next_tokens = scores.argmax(dim=-1, keepdim=True)
+print(f"Input tokens: {tokens}")
+_generated_idx = 0
+print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
 token_ids = torch.cat([tokens, next_tokens], dim=1)
+if streaming:
+    yield token_ids
 cur_len = token_ids.size(1)
 dones = token_ids.new_zeros(batch_size).eq(1)
 # tokens = tokens[:, -1:]
@@ -417,7 +432,8 @@ def _no_beam_search_generate(
 lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
 )
 scores.scatter_(dim=1, index=token_ids, src=token_scores)
+# scores: [bsz, vocab_size]
+import pdb; pdb.set_trace()
 if eos_token_id is not None and length_penalty != 1.0:
 # batch_size x vocab_size
 token_scores = scores / cur_len**length_penalty
@@ -444,7 +460,11 @@ def _no_beam_search_generate(
 next_tokens = next_tokens.masked_fill(dones, pad_token_id)
 tokens = next_tokens.unsqueeze(1)
+print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
 token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
+if streaming:
+    yield token_ids
 end_mask = next_tokens.eq(_eos_token_id)
 dones = dones.__or__(end_mask)
@@ -461,7 +481,8 @@ def _no_beam_search_generate(
 # token_ids[:, -1].masked_fill_(~dones, eos_token_id)
 # TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
 # be able to return multiple real results
-return token_ids[:, None]
+if not streaming:
+    return token_ids[:, None]

 @torch.no_grad()
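
A side note on the mechanism (general Python behavior, not something this diff states): because `_no_beam_search_generate` now contains `yield`, every call to it returns a generator object, and the trailing `return token_ids[:, None]` in the non-streaming branch is delivered as `StopIteration.value` rather than as an ordinary return value. A self-contained toy illustrating that:

    # Toy example of the generator semantics the streaming change relies on;
    # the names here are made up and unrelated to the InternLM code.
    def _toy_generate(streaming=False):
        result = [1, 2, 3]
        if streaming:
            yield result
        if not streaming:
            return result  # surfaces as StopIteration.value, not a plain return

    gen = _toy_generate(streaming=False)  # a generator object, not a list
    try:
        next(gen)
    except StopIteration as stop:
        final = stop.value                # [1, 2, 3]
    print(final)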