InternLM/internlm/apis/inference.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Tuple
import torch
import torch.nn.functional as F
from torch import nn
__all__ = ["SequenceGenerator"]
class InferenceParams:
"""
Intermediate cache objects for inference
"""
def __init__(
self,
max_sequence_len,
max_batch_size,
sequence_len_offset=0,
batch_size_offset=0,
key_value_memory_dict: dict = None,
lengths_per_sample=None,
attention_mask=None,
) -> None:
self.max_sequence_len: int = max_sequence_len
self.max_batch_size: int = max_batch_size
self.sequence_len_offset: int = sequence_len_offset
self.batch_size_offset: int = batch_size_offset
if key_value_memory_dict is None:
key_value_memory_dict = {}
self.key_value_memory_dict: dict = key_value_memory_dict
self.fused_ft_kernel: bool = False
self.lengths_per_sample = lengths_per_sample
self.attention_mask = attention_mask
def reorder_state(self, indices):
if self.lengths_per_sample is not None:
self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0)
for key, value in list(self.key_value_memory_dict.items()):
value = value.index_select(index=indices, dim=0)
self.key_value_memory_dict[key] = value
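# Example (a sketch; the index values are illustrative): in beam search with
# batch_size=2 and num_beams=2, reorder_state(torch.tensor([0, 0, 1, 1]))
# duplicates each sample's cached key/value states along dim 0 so that every
# beam continues from its own sample's prefix cache.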
def _get_model_device(model):
"""
Obtain the device of an nn.Module.
Args:
model: nn.Module
Return: torch.device, or None if the model has no parameters.
"""
assert isinstance(model, nn.Module)
parameters = list(model.parameters())
if len(parameters) == 0:
return None
else:
return parameters[0].device
class SequenceGenerator:
"""
Sequence Generator.
"""
def __init__(
self,
decoder,
eos_token_id,
pad_token_id,
bos_token_id,
additional_eos_token_list=None,
add_eos_when_return=False,
):
self.decoder = decoder
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.additional_eos_token_list = additional_eos_token_list
self.add_eos_when_return = add_eos_when_return
@torch.no_grad()
def generate(
self,
tokens: "torch.LongTensor" = None,
num_return_sequences=1,
max_length: int = 20,
num_beams: int = 1,
do_sample: bool = True,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
repetition_penalty: float = 1,
length_penalty: float = 1.0,
):
"""
Args:
tokens: the initial tokens, with shape [bsz, length]. If tokens is None, the default ``bos_token``
will be used to start generation.
num_return_sequences: the number of returned sequences.
max_length: the maximum length of the generated sequence.
num_beams: the beam size for beam search.
do_sample: whether to use sampling instead of greedy decoding.
temperature: the sampling temperature; only used when do_sample is True.
top_k: sample only from the top_k most likely tokens.
top_p: nucleus sampling; sample from the smallest set of tokens whose cumulative probability exceeds top_p.
Return:
the token sequence with shape [bsz, num_return_sequences, max_length]. If eos_token_id is not None,
each sequence ends with eos_token_id.
"""
assert num_return_sequences <= num_beams, f"num_return_sequences ({num_return_sequences}) must not exceed num_beams ({num_beams})."
if do_sample:
return sample_generate(
self.decoder,
tokens=tokens,
max_length=max_length,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=self.eos_token_id, # the ending token id
additional_eos_token_list=self.additional_eos_token_list,
add_eos_when_return=self.add_eos_when_return,
pad_token_id=self.pad_token_id,
repetition_penalty=repetition_penalty, # the penalty degree for repetition tokens
length_penalty=length_penalty,  # length penalty: values > 1 encourage longer sequences,
# values < 1 encourage shorter ones.
bos_token_id=self.bos_token_id,
)
else:
return greedy_generate(
self.decoder,
tokens=tokens,
max_length=max_length,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
eos_token_id=self.eos_token_id,
additional_eos_token_list=self.additional_eos_token_list,
add_eos_when_return=self.add_eos_when_return,
pad_token_id=self.pad_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
bos_token_id=self.bos_token_id,
)
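# Usage sketch (hypothetical wiring; `model` stands for an InternLM decoder and
# `prompt_ids` for tokenized input that is not shown here):
#   generator = SequenceGenerator(model, eos_token_id=2, pad_token_id=0, bos_token_id=1)
#   out = generator.generate(prompt_ids, max_length=128, do_sample=True,
#                            top_k=50, top_p=0.8, temperature=0.8)
#   # out: LongTensor with shape [batch_size, num_return_sequences, seq_len]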
@torch.no_grad()
def streaming_generate(
self,
tokens: "torch.LongTensor" = None,
max_length: int = 20,
do_sample: bool = True,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
repetition_penalty: float = 1,
length_penalty: float = 1.0,
):
if not do_sample:
temperature = 1
top_k = 50
top_p = 1
yield from _streaming_no_beam_search_generate(
self.decoder,
tokens=tokens,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=self.eos_token_id,
additional_eos_token_list=self.additional_eos_token_list,
add_eos_when_return=self.add_eos_when_return,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
)
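# Streaming usage sketch (hypothetical prompt tensor): every yielded tensor is
# the full [batch_size, cur_len] sequence generated so far, so a caller usually
# decodes only the newly appended last column:
#   for token_ids in generator.streaming_generate(prompt_ids, max_length=128, top_p=0.8):
#       newest = token_ids[:, -1]  # most recently generated token per sample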
@torch.no_grad()
def greedy_generate(
decoder,
tokens=None,
max_length=20,
num_beams=1,
num_return_sequences=1,
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
pad_token_id=0,
repetition_penalty=1,
length_penalty=1.0,
bos_token_id=1,
):
"""
Greedy decoding (beam search decoding when num_beams > 1).
Args:
decoder: the Decoder object.
tokens: shape [batch_size, length]. If tokens is None, generation begins with bos_token_id.
max_length: the maximum length of the generated sequence.
num_beams: the beam size used for decoding.
eos_token_id: the ending token id. If None, sequences are decoded up to max_length.
pad_token_id: the token id used for padding.
repetition_penalty: the penalty applied to repeated tokens.
length_penalty: the penalty applied to sequence length.
"""
if num_beams == 1:
token_ids = _no_beam_search_generate(
decoder,
tokens=tokens,
max_length=max_length,
temperature=1,
top_k=50,
top_p=1,
eos_token_id=eos_token_id,
additional_eos_token_list=additional_eos_token_list,
add_eos_when_return=add_eos_when_return,
do_sample=False,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
)
else:
token_ids = _beam_search_generate(
decoder,
tokens=tokens,
max_length=max_length,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
temperature=1,
top_k=50,
top_p=1,
eos_token_id=eos_token_id,
additional_eos_token_list=additional_eos_token_list,
add_eos_when_return=add_eos_when_return,
do_sample=False,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
)
return token_ids
@torch.no_grad()
def sample_generate(
decoder,
tokens,
max_length=20,
num_beams=1,
num_return_sequences=1,
temperature=1.0,
top_k=50,
top_p=1.0,
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
pad_token_id=0,
repetition_penalty=1.0,
length_penalty=1.0,
bos_token_id=1,
):
"""
Generate sequences by sampling.
Args:
decoder: the Decoder object.
tokens: shape [batch_size, length]. If tokens is None, generation begins with bos_token_id.
max_length: the maximum length of the generated sequence.
num_beams: the beam size used for decoding.
num_return_sequences: the number of returned sequences.
temperature: the sampling temperature.
top_k: sample only from the top_k most likely tokens. (Default: 50)
top_p: nucleus sampling; sample from the smallest set of tokens whose cumulative probability exceeds top_p. (Default: 1.0)
eos_token_id: the ending token id. If None, sequences are decoded up to max_length.
pad_token_id: the token id used for padding.
repetition_penalty: the penalty applied to repeated tokens.
length_penalty: the penalty applied to sequence length.
"""
if num_beams == 1:
token_ids = _no_beam_search_generate(
decoder,
tokens=tokens,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=eos_token_id,
additional_eos_token_list=additional_eos_token_list,
add_eos_when_return=add_eos_when_return,
do_sample=True,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
)
else:
token_ids = _beam_search_generate(
decoder,
tokens=tokens,
max_length=max_length,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=eos_token_id,
additional_eos_token_list=additional_eos_token_list,
add_eos_when_return=add_eos_when_return,
do_sample=True,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
)
return token_ids
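# Dispatch sketch: greedy_generate and sample_generate both route to
# _no_beam_search_generate when num_beams == 1 and to _beam_search_generate
# otherwise, e.g. (hypothetical tensors):
#   ids = sample_generate(model, prompt_ids, max_length=64, num_beams=4,
#                         num_return_sequences=2, top_k=40, top_p=0.9)
#   # ids: [batch_size, 2, seq_len]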
@torch.no_grad()
def _streaming_no_beam_search_generate(
decoder,
tokens,
inference_params=None,
max_length=20,
temperature=1.0,
top_k=50,
top_p=1.0,
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
do_sample=True,
repetition_penalty=1.0,
length_penalty=1.0,
pad_token_id=0,
bos_token_id=1,
):
batch_size = tokens.size(0)
if eos_token_id is not None:
if not isinstance(eos_token_id, (List, Tuple)):
eos_token_id = [eos_token_id]
if additional_eos_token_list is not None:
if not isinstance(additional_eos_token_list, (List, Tuple)):
additional_eos_token_list = [additional_eos_token_list]
eos_token_id.extend(additional_eos_token_list)
eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device)
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
if has_bos:
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
bos_sum = bos_pos.cumsum(dim=-1)
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
else:
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
if inference_params is None:
inference_params = InferenceParams(
max_sequence_len=max_length,
max_batch_size=tokens.size(0),
sequence_len_offset=0,
batch_size_offset=0,
key_value_memory_dict=None,
lengths_per_sample=None,
attention_mask=attention_mask,
)
scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += tokens.size(1)
if eos_token_id is not None:
scores[:, eos_token_id] = -1e12
# The first token generated.
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
yield token_ids
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)
real_max_length = max_length
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
while cur_len < real_max_length:
# batch_size x vocab_size
if has_bos:
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
bos_sum = bos_pos.cumsum(dim=-1)
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
else:
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
inference_params.attention_mask = attention_mask
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += 1
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = (
lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
)
scores.scatter_(dim=1, index=token_ids, src=token_scores)
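# Repetition-penalty sketch (illustrative numbers): with repetition_penalty=1.2,
# a previously generated token whose score is -2.0 becomes -2.4 and one whose
# score is 3.0 becomes 2.5, so already-emitted tokens are down-weighted
# regardless of the sign of their logit.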
# scores: [bsz, vocab_size]
if eos_token_id is not None and length_penalty != 1.0:
# batch_size x vocab_size
eos_token_scores = scores[:, eos_token_id].clone()
scores = scores / cur_len**length_penalty
scores[:, eos_token_id] = eos_token_scores
del eos_token_scores
if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
probs = F.softmax(scores, dim=-1) + 1e-12
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
else:
next_tokens = torch.argmax(scores, dim=-1) # batch_size
if eos_token_id is not None:
# When a sequence is about to reach max_length, force its next token to the primary eos_token_id.
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0])
next_tokens = next_tokens.masked_fill(dones, pad_token_id)
tokens = next_tokens.unsqueeze(1)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
yield token_ids
if eos_token_id is not None:
end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1)
dones = dones.__or__(end_mask)
cur_len += 1
if dones.min() == 1:
break
# token_ids: [bsz, seqlen]
if eos_token_id is not None and add_eos_when_return:
token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1)
yield token_ids
@torch.no_grad()
def _no_beam_search_generate(
decoder,
tokens,
inference_params=None,
max_length=20,
temperature=1.0,
top_k=50,
top_p=1.0,
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
do_sample=True,
repetition_penalty=1.0,
length_penalty=1.0,
pad_token_id=0,
bos_token_id=1,
):
batch_size = tokens.size(0)
if eos_token_id is not None:
if not isinstance(eos_token_id, (List, Tuple)):
eos_token_id = [eos_token_id]
if additional_eos_token_list is not None:
if not isinstance(additional_eos_token_list, (List, Tuple)):
additional_eos_token_list = [additional_eos_token_list]
eos_token_id.extend(additional_eos_token_list)
eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device)
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
if has_bos:
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
bos_sum = bos_pos.cumsum(dim=-1)
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
else:
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
if inference_params is None:
inference_params = InferenceParams(
max_sequence_len=max_length,
max_batch_size=tokens.size(0),
sequence_len_offset=0,
batch_size_offset=0,
key_value_memory_dict=None,
lengths_per_sample=None,
attention_mask=attention_mask,
)
scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += tokens.size(1)
if eos_token_id is not None:
scores[:, eos_token_id] = -1e12
# The first token generated.
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)
real_max_length = max_length
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
while cur_len < real_max_length:
# batch_size x vocab_size
if has_bos:
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
bos_sum = bos_pos.cumsum(dim=-1)
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
else:
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
inference_params.attention_mask = attention_mask
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += 1
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = (
lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
)
scores.scatter_(dim=1, index=token_ids, src=token_scores)
# scores: [bsz, vocab_size]
if eos_token_id is not None and length_penalty != 1.0:
# batch_size x vocab_size
eos_token_scores = scores[:, eos_token_id].clone()
scores = scores / cur_len**length_penalty
scores[:, eos_token_id] = eos_token_scores
del eos_token_scores
if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
probs = F.softmax(scores, dim=-1) + 1e-12
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
else:
next_tokens = torch.argmax(scores, dim=-1) # batch_size
if eos_token_id is not None:
# When a sequence is about to reach max_length, force its next token to the primary eos_token_id.
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0])
next_tokens = next_tokens.masked_fill(dones, pad_token_id)
tokens = next_tokens.unsqueeze(1)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
if eos_token_id is not None:
end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1)
dones = dones.__or__(end_mask)
cur_len += 1
if dones.min() == 1:
break
# token_ids: [bsz, seqlen]
if eos_token_id is not None and add_eos_when_return:
token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1)
# In order to maintain consistency with the results returned by beam search,
# a new dimension is added here representing num_return_sequences.
# token_ids: [bsz, num_return_sequences, seqlen]
return token_ids[:, None]
@torch.no_grad()
def _beam_search_generate(
decoder,
tokens,
inference_params=None,
max_length=20,
num_beams=4,
num_return_sequences=1,
temperature=1.0,
top_k=50,
top_p=1.0,
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
do_sample=True,
repetition_penalty=1.0,
length_penalty=1.0,
pad_token_id=0,
bos_token_id=1,
) -> torch.LongTensor:
device = _get_model_device(decoder)
batch_size = tokens.size(0)
if eos_token_id is not None:
if not isinstance(eos_token_id, (List, Tuple)):
eos_token_id = [eos_token_id]
if additional_eos_token_list is not None:
if not isinstance(additional_eos_token_list, (List, Tuple)):
additional_eos_token_list = [additional_eos_token_list]
eos_token_id.extend(additional_eos_token_list)
eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device)
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
if has_bos:
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
bos_sum = bos_pos.cumsum(dim=-1)
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
else:
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
if inference_params is None:
inference_params = InferenceParams(
max_sequence_len=max_length,
max_batch_size=tokens.size(0),
sequence_len_offset=0,
batch_size_offset=0,
key_value_memory_dict=None,
lengths_per_sample=None,
attention_mask=attention_mask,
)
scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += tokens.size(1)
if eos_token_id is not None:
scores[:, eos_token_id] = -1e12
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."
# The first token generated.
if do_sample:
probs = F.softmax(scores, dim=-1) + 1e-12
# (batch_size, num_beams)
next_tokens = torch.multinomial(probs, num_samples=num_beams)
logits = probs.log()
# (batch_size, num_beams)
next_scores = logits.gather(dim=1, index=next_tokens)
else:
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
# obtain (batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
inference_params.reorder_state(indices)
# batch_size * num_beams x length
tokens = tokens.index_select(dim=0, index=indices)
# generated tokens (batch_size * num_beams, cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
dones = [False] * batch_size
beam_scores = next_scores.view(-1) # batch_size * num_beams
cur_len = token_ids.size(1)
real_max_length = max_length
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
hypos = [
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# 0, num_beams, 2*num_beams, ...
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
while cur_len < real_max_length:
if has_bos:
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
bos_sum = bos_pos.cumsum(dim=-1)
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
else:
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
to_atten_x = bos_pos[:, :, None]
to_atten_y = bos_pos[:, None, :]
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
inference_params.attention_mask = attention_mask
# (bsz x num_beams, vocab_size)
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += 1
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = (
lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
)
scores.scatter_(dim=1, index=token_ids, src=token_scores)
if eos_token_id is not None:
max_len_eos_mask = max_lengths.eq(cur_len + 1)
# When a sequence is about to reach max_length, boost the score of the primary eos_token_id so it is selected.
eos_scores = scores[:, eos_token_id[0]]
scores[:, eos_token_id[0]] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores)
if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
probs = F.softmax(scores, dim=-1) + 1e-12
# batch_size' x (num_beams+1)
_tokens = torch.multinomial(probs, num_samples=num_beams + 1)
logits = probs.log()
# batch_size' x (num_beams+1)
_scores = logits.gather(dim=1, index=_tokens)
# batch_size' x (num_beams+1)
_scores = _scores + beam_scores[:, None]
_scores = _scores.view(batch_size, num_beams * (num_beams + 1))
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
# (batch_size, 2*num_beams)
next_tokens = _tokens.gather(dim=1, index=ids)
# (batch_size, 2*num_beams)
from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()
else:
# (batch_size * num_beams, vocab_size)
scores = F.log_softmax(scores, dim=-1)
# (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None]
# (batch_size, num_beams*vocab_size)
_scores = _scores.view(batch_size, -1)
# (bsz, 2*num_beams)
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
# (batch_size, 2*num_beams)
from_which_beam = torch.floor(ids.float() / vocab_size).long()
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
not_eos_mask = torch.all(next_tokens[..., None].ne(eos_token_id), dim=-1)
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
keep_mask = not_eos_mask.__and__(keep_mask)
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
beam_scores = _next_scores.view(-1)
flag = True
if cur_len + 1 == real_max_length:
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
else:
effective_eos_mask = torch.any(
next_tokens[:, :num_beams][..., None].eq(eos_token_id), dim=-1
) # batch_size x num_beams
if effective_eos_mask.sum().gt(0):
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
else:
flag = False
if flag:
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
for batch_idx, beam_ind, beam_idx in zip(
eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()
):
if not dones[batch_idx]:
score = next_scores[batch_idx, beam_ind].item()
if eos_token_id is not None:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
else:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
inference_params.reorder_state(reorder_inds)
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)
for batch_idx in range(batch_size):
dones[batch_idx] = (
dones[batch_idx]
or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
or max_lengths[batch_idx * num_beams] == cur_len + 1
)
cur_len += 1
if all(dones):
break
# select the best hypotheses
tgt_len = token_ids.new_zeros(batch_size, num_return_sequences)
best = []
for i, hypotheses in enumerate(hypos):
# best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
sorted_hyp = list(sorted(hypotheses.hyp, key=lambda x: x[0], reverse=True))
_best = []
for j, hyp in zip(range(num_return_sequences), sorted_hyp):
hyp = hyp[1]
if eos_token_id is not None and add_eos_when_return:
# When an eos is forced at the end of the returned result, use the primary eos_token_id.
hyp = torch.cat([hyp, token_ids.new_ones(1) * eos_token_id[0]])
tgt_len[i, j] = len(hyp)
_best.append(hyp)
best.append(_best)
# generate target batch
decoded = token_ids.new_zeros(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
for j, _hypo in enumerate(hypo):
decoded[i, j, : tgt_len[i, j]] = _hypo
# decoded: [bsz, num_return_sequences, seqlen]
return decoded
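# Beam-reorder sketch (illustrative values): with batch_size=2 and num_beams=3,
# batch_inds_with_numbeams_interval is [[0], [3]]; if _from_which_beam is
# [[1, 0, 2], [2, 2, 0]], the flattened reorder index is [1, 0, 2, 5, 5, 3],
# i.e. each surviving beam copies its token history and KV cache from the beam
# it branched from.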
class BeamHypotheses(object):
"""
BeamHypotheses
"""
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""Initialize n-best list of hypotheses."""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""Number of hypotheses in the list."""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""Add a new hypothesis to the list."""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
"""If there are enough hypotheses and that none of the hypotheses being
generated can become better than the worst one in the heap, then we are
done with this sentence."""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length**self.length_penalty
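# Scoring sketch (illustrative numbers): BeamHypotheses.add normalizes by
# len(hyp) ** length_penalty, so a 10-token hypothesis with sum_logprobs = -12.0
# scores -1.2 at length_penalty = 1.0 and -0.12 at length_penalty = 2.0; larger
# penalties therefore favor longer hypotheses.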
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering; positions that do not
meet the criteria are set to filter_value.
Args:
logits: the logits, with shape [bsz, vocab_size].
top_k: if greater than 0, keep only the top_k highest-scoring tokens and set the rest to
filter_value.
top_p: keep the smallest set of tokens whose cumulative probability exceeds top_p
(nucleus sampling, http://arxiv.org/abs/1904.09751); the rest are set to filter_value.
filter_value: the value assigned to filtered-out positions.
min_tokens_to_keep: each sample keeps at least this many tokens in the filtered distribution.
"""
if top_k > 0:
# Safety check
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
# Remove all tokens with a probability less than the last token of
# the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
# (the first token above the threshold is kept by the shift below)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
# (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
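# Minimal smoke-test sketch (not part of the original module): a dummy decoder
# that returns random logits is enough to exercise SequenceGenerator end to end.
# The vocabulary size, token ids, and sampling hyper-parameters below are arbitrary.
if __name__ == "__main__":

    class _DummyDecoder(nn.Module):
        """Toy decoder: ignores the KV cache and returns uniform random logits."""

        def __init__(self, vocab_size=32):
            super().__init__()
            self.vocab_size = vocab_size
            self.dummy = nn.Parameter(torch.zeros(1))  # gives the module a device

        def forward(self, input_ids, inference_params=None):
            batch_size, seq_len = input_ids.shape
            return torch.randn(batch_size, seq_len, self.vocab_size, device=input_ids.device)

    generator = SequenceGenerator(_DummyDecoder(), eos_token_id=2, pad_token_id=0, bos_token_id=1)
    prompt = torch.tensor([[1, 5, 6, 7]], dtype=torch.long)  # starts with bos_token_id=1
    output = generator.generate(prompt, max_length=16, do_sample=True, top_k=10, top_p=0.9)
    print(output.shape)  # (1, 1, generated_len), with generated_len <= 16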