mirror of https://github.com/InternLM/InternLM
debug for web_demo_internlm
parent b9c813a972
commit 94fdd178ba
@@ -85,6 +85,7 @@ class SequenceGenerator:
         top_p: float = 1.0,
         repetition_penalty: float = 1,
         length_penalty: float = 1.0,
+        streaming=False
     ):
         """
         Args:
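For orientation, a minimal sketch of how a caller such as web_demo_internlm might exercise these parameters; the constructor arguments and the do_sample flag are assumptions for illustration, not part of this diff:

    # Hypothetical usage; SequenceGenerator construction details are assumed.
    generator = SequenceGenerator(decoder=model, eos_token_id=2, pad_token_id=0, bos_token_id=1)
    output_ids = generator.generate(
        input_ids,                # [batch, prompt_len] prompt tokens
        do_sample=True,           # assumed switch between the sampling and greedy paths
        top_p=0.8,                # nucleus sampling threshold
        repetition_penalty=1.02,  # >1 pushes down tokens that were already generated
        length_penalty=1.0,       # neutral; >1 would favor longer outputs
        streaming=False,          # True turns the call into a generator of partial outputs
    )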
@@ -119,6 +120,7 @@ class SequenceGenerator:
                 length_penalty=length_penalty,  # penalty for length: if > 1, encourages longer sequences;
                 # otherwise, encourages shorter sequences.
                 bos_token_id=self.bos_token_id,
+                streaming=streaming
             )
         else:
             return greedy_generate(
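To make the length-penalty comment concrete: assuming scores are accumulated log-probabilities (hence negative), they are divided by cur_len**length_penalty later in this file, so a penalty above 1 shrinks a long sequence's score toward zero and favors it. A toy calculation with invented numbers:

    length_penalty = 1.2
    short_seq = -20.0 / (10 ** length_penalty)  # ~ -1.26 for a 10-token sequence
    long_seq = -36.0 / (20 ** length_penalty)   # ~ -0.99 for a 20-token sequence
    # After normalization the longer sequence scores higher (-0.99 > -1.26),
    # so length_penalty > 1 encourages long sequences; < 1 encourages short ones.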
@@ -132,6 +134,7 @@ class SequenceGenerator:
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
                 bos_token_id=self.bos_token_id,
+                streaming=streaming
             )
@@ -147,6 +150,7 @@ def greedy_generate(
     repetition_penalty=1,
     length_penalty=1.0,
     bos_token_id=1,
+    streaming=False,
     feat_mask=None,
     ffn_mask=None,
     layer_mask=None,
@@ -179,6 +183,7 @@ def greedy_generate(
         length_penalty=length_penalty,
         pad_token_id=pad_token_id,
         bos_token_id=bos_token_id,
+        streaming=streaming,
         feat_mask=feat_mask,
         ffn_mask=ffn_mask,
         layer_mask=layer_mask,
@@ -199,6 +204,7 @@ def greedy_generate(
         length_penalty=length_penalty,
         pad_token_id=pad_token_id,
         bos_token_id=bos_token_id,
+        # streaming=streaming,
         feat_mask=feat_mask,
         ffn_mask=ffn_mask,
         layer_mask=layer_mask,
@@ -222,6 +228,7 @@ def sample_generate(
     repetition_penalty=1.0,
     length_penalty=1.0,
     bos_token_id=1,
+    streaming=False,
 ):
     """
     Generate a sequence by sampling.
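The docstring is terse, so here is a generic sketch of the nucleus (top-p) filtering that a top_p argument conventionally controls; this illustrates the technique, it is not the exact code from this file:

    import torch

    def top_p_filter(logits, top_p=0.8, filter_value=-1e12):
        # Keep the smallest set of top-scoring tokens whose cumulative
        # probability exceeds top_p; mask everything else before sampling.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right: always keep the top token
        remove[..., 0] = False
        mask = remove.scatter(dim=-1, index=sorted_idx, src=remove)
        return logits.masked_fill(mask, filter_value)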
@@ -255,6 +262,7 @@ def sample_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            streaming=streaming,
         )
     else:
         token_ids = _beam_search_generate(
@@ -272,6 +280,7 @@ def sample_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            # streaming=streaming,
         )
     return token_ids
@@ -291,11 +300,11 @@ def _no_beam_search_generate(
     length_penalty=1.0,
     pad_token_id=0,
     bos_token_id=1,
+    streaming=False,
     feat_mask=None,
     ffn_mask=None,
     layer_mask=None,
 ):
     # delete num_return_sequences=1 for lint check;
     batch_size = tokens.size(0)
     if eos_token_id is None:
         _eos_token_id = -1
@@ -356,8 +365,14 @@ def _no_beam_search_generate(
         inference_params.sequence_len_offset += tokens.size(1)
     if _eos_token_id != -1:
         scores[:, _eos_token_id] = -1e12
+    # The first token generated.
     next_tokens = scores.argmax(dim=-1, keepdim=True)
+    print(f"Input tokens: {tokens}")
+    _generated_idx = 0
+    print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
     token_ids = torch.cat([tokens, next_tokens], dim=1)
+    if streaming:
+        yield token_ids
     cur_len = token_ids.size(1)
     dones = token_ids.new_zeros(batch_size).eq(1)
     # tokens = tokens[:, -1:]
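The scores[:, _eos_token_id] = -1e12 line above masks EOS before the first argmax so the model cannot emit an empty generation. A self-contained sketch of the idea, with fake logits and invented sizes:

    import torch

    scores = torch.randn(2, 100)       # [batch, vocab] logits for the first step
    eos_token_id = 2
    scores[:, eos_token_id] = -1e12    # EOS can no longer win the argmax
    first_tokens = scores.argmax(dim=-1, keepdim=True)  # [batch, 1], never EOS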
@@ -417,7 +432,8 @@ def _no_beam_search_generate(
                 lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
             )
             scores.scatter_(dim=1, index=token_ids, src=token_scores)

         # scores: [bsz, vocab_size]
+        import pdb; pdb.set_trace()
         if eos_token_id is not None and length_penalty != 1.0:
             # batch_size x vocab_size
             token_scores = scores / cur_len**length_penalty
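The masked expression above implements a sign-aware repetition penalty: with a penalty above 1, previously generated tokens with negative scores are multiplied (made more negative) while those with non-negative scores are divided (made smaller), so both are demoted. A small illustration with invented values:

    import torch

    token_scores = torch.tensor([2.0, -3.0])  # scores of two already-generated tokens
    repetition_penalty = 1.5
    lt_zero_mask = token_scores.lt(0).float()
    ge_zero_mask = token_scores.ge(0).float()
    penalized = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
    # tensor([ 1.3333, -4.5000]): both tokens end up with lower scores than before,
    # making immediate repetition less likely.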
@@ -444,7 +460,11 @@ def _no_beam_search_generate(
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)
         tokens = next_tokens.unsqueeze(1)

+        print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1

         token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
+        if streaming:
+            yield token_ids

         end_mask = next_tokens.eq(_eos_token_id)
         dones = dones.__or__(end_mask)
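With streaming=True, each yield hands the caller the full token_ids accumulated so far, so a front end can re-render the partial output after every step. A sketch of the consuming side, reusing the hypothetical generator from above; decode() is an assumed stand-in for the demo's detokenizer:

    # Hypothetical caller loop for the streaming path.
    for partial_ids in generator.generate(input_ids, streaming=True):
        text_so_far = decode(partial_ids[0])  # decode() is assumed, not part of this diff
        print(text_so_far, end="\r")          # redraw the growing response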
@@ -461,7 +481,8 @@ def _no_beam_search_generate(
     # token_ids[:, -1].masked_fill_(~dones, eos_token_id)
     # TODO: here we simply add an extra dimension for interface compatibility; in the future this
     # should be able to return multiple real results
-    return token_ids[:, None]
+    if not streaming:
+        return token_ids[:, None]


 @torch.no_grad()
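One subtlety behind the guarded return: because _no_beam_search_generate contains yield, it is always a generator function, so return token_ids[:, None] only sets StopIteration.value rather than returning directly. A minimal sketch of how a non-streaming caller could recover that value (an assumed pattern, not shown in this diff):

    def run_to_completion(gen):
        # Drain a generator and return the value carried by its final StopIteration.
        try:
            while True:
                next(gen)
        except StopIteration as stop:
            return stop.value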