mirror of https://github.com/InternLM/InternLM

commit 94fdd178ba (parent b9c813a972): debug for web_demo_internlm
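Summary of the change below: the commit threads a streaming option through SequenceGenerator.generate, greedy_generate, sample_generate, and _no_beam_search_generate, letting the no-beam-search path yield intermediate token_ids as they are produced; the beam-search call sites only gain a commented-out placeholder. It also leaves debugging aids in the code (per-step print statements and a pdb.set_trace() breakpoint), consistent with the commit message.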
@@ -85,6 +85,7 @@ class SequenceGenerator:
         top_p: float = 1.0,
         repetition_penalty: float = 1,
         length_penalty: float = 1.0,
+        streaming=False
     ):
         """
         Args:
@@ -119,6 +120,7 @@ class SequenceGenerator:
                 length_penalty=length_penalty,  # the penalty for length. if it > 1, then encourages long sequence.
                 # Otherwise, encourages short sequence.
                 bos_token_id=self.bos_token_id,
+                streaming=streaming
             )
         else:
             return greedy_generate(
@@ -132,6 +134,7 @@ class SequenceGenerator:
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
                 bos_token_id=self.bos_token_id,
+                streaming=streaming
             )


@@ -147,6 +150,7 @@ def greedy_generate(
     repetition_penalty=1,
     length_penalty=1.0,
     bos_token_id=1,
+    streaming=False,
     feat_mask=None,
     ffn_mask=None,
     layer_mask=None,
@@ -179,6 +183,7 @@ def greedy_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            streaming=streaming,
             feat_mask=feat_mask,
             ffn_mask=ffn_mask,
             layer_mask=layer_mask,
@@ -199,6 +204,7 @@ def greedy_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            # streaming=streaming,
             feat_mask=feat_mask,
             ffn_mask=ffn_mask,
             layer_mask=layer_mask,
@@ -222,6 +228,7 @@ def sample_generate(
     repetition_penalty=1.0,
     length_penalty=1.0,
     bos_token_id=1,
+    streaming=False,
 ):
     """
     generate sequence in sampling way.
@@ -255,6 +262,7 @@ def sample_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            streaming=streaming,
         )
     else:
         token_ids = _beam_search_generate(
@@ -272,6 +280,7 @@ def sample_generate(
             length_penalty=length_penalty,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            # streaming=streaming,
         )
     return token_ids

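Note that in both beam-search call sites the flag arrives only as the commented-out "# streaming=streaming," line, so after this commit only the no-beam-search path actually streams. Since the diff shows just the plumbing, here is a minimal sketch of how a caller might consume the new flag; the constructor arguments and the do_sample keyword are assumptions, only the streaming keyword comes from this commit:

    # Hypothetical usage sketch; the constructor arguments are invented for illustration.
    generator = SequenceGenerator(decoder=model, eos_token_id=2, pad_token_id=0, bos_token_id=1)
    for partial_ids in generator.generate(tokens, do_sample=True, streaming=True):
        # Each yielded tensor is the full prefix decoded so far: [batch, cur_len].
        print(partial_ids[:, -1])  # the newest token for each batch element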
@@ -291,11 +300,11 @@ def _no_beam_search_generate(
     length_penalty=1.0,
     pad_token_id=0,
     bos_token_id=1,
+    streaming=False,
     feat_mask=None,
     ffn_mask=None,
     layer_mask=None,
 ):
-    # delete num_return_sequences=1 for lint check;
     batch_size = tokens.size(0)
     if eos_token_id is None:
         _eos_token_id = -1
@@ -356,8 +365,14 @@ def _no_beam_search_generate(
     inference_params.sequence_len_offset += tokens.size(1)
     if _eos_token_id != -1:
         scores[:, _eos_token_id] = -1e12
+    # The first token generated.
     next_tokens = scores.argmax(dim=-1, keepdim=True)
+    print(f"Input tokens: {tokens}")
+    _generated_idx = 0
+    print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
     token_ids = torch.cat([tokens, next_tokens], dim=1)
+    if streaming:
+        yield token_ids
     cur_len = token_ids.size(1)
     dones = token_ids.new_zeros(batch_size).eq(1)
     # tokens = tokens[:, -1:]
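For context on the hunk above: the scores[:, _eos_token_id] = -1e12 line suppresses EOS before the first argmax, so the model must emit at least one real token. A toy illustration with invented values:

    import torch

    scores = torch.tensor([[0.1, 2.0, 0.3]])  # pretend index 1 is _eos_token_id
    scores[:, 1] = -1e12                      # the eos logit becomes hugely negative
    print(scores.argmax(dim=-1))              # tensor([2]): eos can no longer win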
@@ -417,7 +432,8 @@ def _no_beam_search_generate(
                 lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
             )
             scores.scatter_(dim=1, index=token_ids, src=token_scores)
+        # scores: [bsz, vocab_size]
+        import pdb; pdb.set_trace()
         if eos_token_id is not None and length_penalty != 1.0:
             # batch_size x vocab_size
             token_scores = scores / cur_len**length_penalty
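The context lines above carry the repetition-penalty update. Below is a self-contained sketch of that computation; the gather call and mask construction are assumptions inferred from the two lines visible in the diff, and the values are invented:

    import torch

    scores = torch.randn(2, 10)                  # [bsz, vocab_size] next-token scores
    token_ids = torch.tensor([[1, 3], [2, 2]])   # tokens generated so far
    repetition_penalty = 1.2

    token_scores = scores.gather(dim=1, index=token_ids)
    lt_zero_mask = token_scores.lt(0).float()    # 1.0 where a score is negative
    ge_zero_mask = lt_zero_mask.eq(0).float()    # 1.0 where a score is non-negative
    # Already-generated tokens are discouraged: negative scores are scaled further
    # down, positive scores are scaled toward zero (the expression from the diff).
    token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
    scores.scatter_(dim=1, index=token_ids, src=token_scores)

The import pdb; pdb.set_trace() added just below those lines is a temporary interactive breakpoint, in line with the debugging purpose of this commit.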
@@ -444,7 +460,11 @@ def _no_beam_search_generate(
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)
         tokens = next_tokens.unsqueeze(1)
+
+        print(f"gt - {_generated_idx} | ", next_tokens); _generated_idx += 1
+
         token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
+        if streaming:
+            yield token_ids
+
         end_mask = next_tokens.eq(_eos_token_id)
         dones = dones.__or__(end_mask)
@@ -461,7 +481,8 @@ def _no_beam_search_generate(
     # token_ids[:, -1].masked_fill_(~dones, eos_token_id)
     # TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
     # be able to return multiple real results
-    return token_ids[:, None]
+    if not streaming:
+        return token_ids[:, None]


 @torch.no_grad()
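One caveat in the final hunk deserves a note: because _no_beam_search_generate now contains yield statements, Python treats it as a generator function even when streaming=False, so the guarded return token_ids[:, None] surfaces as the StopIteration value rather than as an ordinary return. A minimal, self-contained demonstration (names invented):

    def gen(streaming=False):
        if streaming:
            yield "partial"
        if not streaming:
            return "final"  # becomes StopIteration.value, not a plain return value

    g = gen(streaming=False)  # calling a generator function only builds a generator
    try:
        next(g)               # runs the body; hits the return without yielding
    except StopIteration as stop:
        print(stop.value)     # prints "final"

Non-streaming callers therefore have to drain the generator to recover the result, which may be part of what this debug commit was probing.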