#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch import nn

__all__ = ["SequenceGenerator"]


class InferenceParams:
    """
    Intermediate cache objects for inference.

    Holds the bookkeeping that is handed to the decoder on every step
    (offsets, the key/value cache dict and the current attention mask).
    """

    def __init__(
        self,
        max_sequence_len,
        max_batch_size,
        sequence_len_offset=0,
        batch_size_offset=0,
        key_value_memory_dict: dict = None,
        lengths_per_sample=None,
        attention_mask=None,
    ) -> None:
        self.max_sequence_len: int = max_sequence_len
        self.max_batch_size: int = max_batch_size
        self.sequence_len_offset: int = sequence_len_offset
        self.batch_size_offset: int = batch_size_offset
        # A fresh dict per instance; never share a mutable default.
        if key_value_memory_dict is None:
            key_value_memory_dict = {}
        self.key_value_memory_dict: dict = key_value_memory_dict
        self.fused_ft_kernel: bool = False
        self.lengths_per_sample = lengths_per_sample
        self.attention_mask = attention_mask

    def reorder_state(self, indices):
        """
        Re-index the cached state along dim 0 with ``indices``.

        Applies ``index_select(dim=0)`` to ``lengths_per_sample`` (when set) and
        to every value stored in ``key_value_memory_dict``.
        """
        if self.lengths_per_sample is not None:
            self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0)
        # Iterate over a snapshot because the dict is updated while looping.
        for key, value in list(self.key_value_memory_dict.items()):
            self.key_value_memory_dict[key] = value.index_select(index=indices, dim=0)


def _get_model_device(model):
    """
    Obtain the device of an nn.Module.

    Args:
        model: nn.Module

    Return:
        torch.device of the first parameter, or None if the model has no parameters.
    """
    assert isinstance(model, nn.Module)
    first_param = next(model.parameters(), None)
    return None if first_param is None else first_param.device


def _build_attention_mask(token_ids, bos_token_id, has_bos):
    """
    Build the boolean [bsz, length, length] mask stored on ``InferenceParams.attention_mask``.

    A position is "flagged" as follows:
      * ``has_bos`` true: every position strictly before the last ``bos_token_id`` of its row;
      * otherwise: every position that holds ``bos_token_id``.
    ``mask[b, i, j]`` is True when position ``i`` or position ``j`` is flagged.

    NOTE(review): how the decoder interprets True entries is not visible in this file.
    """
    bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
    if has_bos:
        bos_sum = bos_pos.cumsum(dim=-1)
        # The running count equals the row total from the last BOS onwards.
        bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
    to_atten_x = bos_pos[:, :, None]
    to_atten_y = bos_pos[:, None, :]
    return torch.logical_or(to_atten_x, to_atten_y).eq(1)


def _last_step_scores(decoder, input_ids, inference_params, feat_mask, ffn_mask, layer_mask):
    """
    Call the decoder and return the float scores of the last position, shape [bsz, vocab_size].

    The optional masks are only forwarded when they are in use, so decoders that do
    not accept ``feat_mask`` / ``ffn_mask`` / ``layer_mask`` keep working:
      * ``layer_mask`` given                 -> all three masks are passed;
      * ``feat_mask`` or ``ffn_mask`` given  -> ``feat_mask`` and ``ffn_mask`` are passed;
      * none given                           -> only ``input_ids`` and ``inference_params``.
    """
    kwargs = {"input_ids": input_ids, "inference_params": inference_params}
    if layer_mask is not None:
        kwargs.update(feat_mask=feat_mask, ffn_mask=ffn_mask, layer_mask=layer_mask)
    elif feat_mask is not None or ffn_mask is not None:
        kwargs.update(feat_mask=feat_mask, ffn_mask=ffn_mask)
    scores = decoder(**kwargs)
    # Some decoders return (logits, ...); the logits are taken from slot 0.
    if isinstance(scores, (list, tuple)):
        scores = scores[0]
    return scores[:, -1].float()


def _apply_repetition_penalty(scores, token_ids, repetition_penalty):
    """
    Rescale, in place, the scores of tokens already present in ``token_ids``.

    Negative scores are multiplied by ``repetition_penalty`` and non-negative scores
    are divided by it.
    """
    token_scores = scores.gather(dim=1, index=token_ids)
    lt_zero_mask = token_scores.lt(0).float()
    ge_zero_mask = lt_zero_mask.eq(0).float()
    token_scores = (
        lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
    )
    scores.scatter_(dim=1, index=token_ids, src=token_scores)


class SequenceGenerator:
    """
    Sequence Generator.

    Thin front-end that stores the decoder together with the special token ids and
    dispatches to ``sample_generate`` or ``greedy_generate``.
    """

    def __init__(self, decoder, eos_token_id, pad_token_id, bos_token_id):
        self.decoder = decoder
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    @torch.no_grad()
    def generate(
        self,
        tokens: "torch.LongTensor" = None,
        num_return_sequences=1,
        max_length: int = 20,
        num_beams: int = 1,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        repetition_penalty: float = 1,
        length_penalty: float = 1.0,
    ):
        """
        Args:
            tokens: the beginning tokens whose shape is [bsz, length]. Although the default is
                None, the generation helpers call ``tokens.size(0)``, so a tensor must be supplied.
            num_return_sequences: number of returned sequences; must not exceed ``num_beams``.
            max_length: the max length of generated sequence (the prompt counts towards it).
            num_beams: the size of beam search.
            do_sample: whether using sample.
            temperature: it's meaningful when do_sample is True.
            top_k: sampling from top_k.
            top_p: sampling from top_p tokens(nucleus sampling).
            repetition_penalty: the penalty degree for repetition tokens.
            length_penalty: the penalty for length (original note: > 1 encourages long
                sequences, otherwise short ones).

        Return:
            the token sequence whose shape is [bsz, num_return_sequences, length]; decoding
            stops once the sequence length reaches ``max_length``.
        """
        assert num_return_sequences <= num_beams, f"The `{num_return_sequences}` must be less than `{num_beams}`..."
        if do_sample:
            return sample_generate(
                self.decoder,
                tokens=tokens,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                eos_token_id=self.eos_token_id,  # the ending token id
                pad_token_id=self.pad_token_id,
                repetition_penalty=repetition_penalty,  # the penalty degree for repetition tokens
                length_penalty=length_penalty,  # the penalty for length
                bos_token_id=self.bos_token_id,
            )
        else:
            return greedy_generate(
                self.decoder,
                tokens=tokens,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                bos_token_id=self.bos_token_id,
            )


@torch.no_grad()
def greedy_generate(
    decoder,
    tokens=None,
    max_length=20,
    num_beams=1,
    num_return_sequences=1,
    eos_token_id=None,
    pad_token_id=0,
    repetition_penalty=1,
    length_penalty=1.0,
    bos_token_id=1,
    feat_mask=None,
    ffn_mask=None,
    layer_mask=None,
):
    """
    Search sequence greedily (argmax when ``num_beams == 1``, deterministic beam search otherwise).

    Args:
        decoder: the Decoder object.
        tokens: the shape is [batch size, length]. Must be a tensor; the helpers read
            ``tokens.size(0)``.
        max_length: the max length for generated sequence.
        num_beams: the size of beam to decode.
        num_return_sequences: number of returned sequences (used by beam search only).
        eos_token_id: the ending token id. If None, the decode length is max_length.
        pad_token_id: the token id of pad.
        repetition_penalty: the penalty degree for repetition tokens
        length_penalty: the penalty for length.
        bos_token_id: the token id used when building the attention mask.
        feat_mask, ffn_mask, layer_mask: optional masks forwarded to the decoder when not None.
    """
    if num_beams == 1:
        token_ids = _no_beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            temperature=1,
            top_k=50,
            top_p=1,
            eos_token_id=eos_token_id,
            do_sample=False,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            feat_mask=feat_mask,
            ffn_mask=ffn_mask,
            layer_mask=layer_mask,
        )
    else:
        token_ids = _beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            temperature=1,
            top_k=50,
            top_p=1,
            eos_token_id=eos_token_id,
            do_sample=False,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            feat_mask=feat_mask,
            ffn_mask=ffn_mask,
            layer_mask=layer_mask,
        )

    return token_ids


@torch.no_grad()
def sample_generate(
    decoder,
    tokens,
    max_length=20,
    num_beams=1,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    eos_token_id=None,
    pad_token_id=0,
    repetition_penalty=1.0,
    length_penalty=1.0,
    bos_token_id=1,
):
    """
    Generate sequence in sampling way.

    Args:
        decoder: the Decoder object.
        tokens: the shape is [batch size, length]. Must be a tensor; the helpers read
            ``tokens.size(0)``.
        max_length: the max length for generated sequence.
        num_beams: the size of beam to decode.
        num_return_sequences: number of returned sequence.
        temperature: annealing magnitude during sampling.
        top_k: sampling from top_k. (Default: 50)
        top_p: sampling from top_p tokens(nucleus sampling). (Default: 1.0)
        eos_token_id: the ending token id. If None, the decode length is max_length.
        pad_token_id: the token id of pad.
        repetition_penalty: the penalty degree for repetition tokens
        length_penalty: the penalty for length.
        bos_token_id: the token id used when building the attention mask.
    """
    if num_beams == 1:
        token_ids = _no_beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
        )
    else:
        token_ids = _beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
        )
    return token_ids


@torch.no_grad()
def _no_beam_search_generate(
    decoder,
    tokens,
    inference_params=None,
    max_length=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    eos_token_id=None,
    do_sample=True,
    repetition_penalty=1.0,
    length_penalty=1.0,
    pad_token_id=0,
    bos_token_id=1,
    feat_mask=None,
    ffn_mask=None,
    layer_mask=None,
):
    """
    Decode one sequence per batch row without beam search.

    The first generated token is always the argmax of the prompt scores (with EOS
    suppressed); subsequent tokens are sampled when ``do_sample`` is True and taken by
    argmax otherwise. When ``eos_token_id`` is set, the token written at position
    ``max_length`` is forced to EOS and rows that already emitted EOS are filled with
    ``pad_token_id``.

    Returns:
        LongTensor of shape [bsz, 1, length] (the middle dimension exists for interface
        compatibility with beam search).
    """
    # delete num_return_sequences=1 for lint check;
    batch_size = tokens.size(0)
    _eos_token_id = -1 if eos_token_id is None else eos_token_id

    has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
    attention_mask = _build_attention_mask(tokens, bos_token_id, has_bos)

    if inference_params is None:
        inference_params = InferenceParams(
            max_sequence_len=max_length,
            max_batch_size=tokens.size(0),
            sequence_len_offset=0,
            batch_size_offset=0,
            key_value_memory_dict=None,
            lengths_per_sample=None,
            attention_mask=attention_mask,
        )

    # Prompt pass: feed the whole prompt once.
    scores = _last_step_scores(decoder, tokens, inference_params, feat_mask, ffn_mask, layer_mask)
    inference_params.sequence_len_offset += tokens.size(1)
    if _eos_token_id != -1:
        # Never end the sequence on the very first generated token.
        scores[:, _eos_token_id] = -1e12
    next_tokens = scores.argmax(dim=-1, keepdim=True)
    token_ids = torch.cat([tokens, next_tokens], dim=1)
    cur_len = token_ids.size(1)
    dones = token_ids.new_zeros(batch_size).eq(1)

    real_max_length = max_length
    max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

    while cur_len < real_max_length:
        inference_params.attention_mask = _build_attention_mask(token_ids, bos_token_id, has_bos)

        # batch_size x vocab_size; only the newest token is fed to the decoder.
        scores = _last_step_scores(
            decoder, token_ids[:, -1:], inference_params, feat_mask, ffn_mask, layer_mask
        )
        inference_params.sequence_len_offset += 1

        if repetition_penalty != 1.0:
            _apply_repetition_penalty(scores, token_ids, repetition_penalty)

        if eos_token_id is not None and length_penalty != 1.0:
            # Rescale every score except the EOS column. torch.where keeps each value in
            # its own column; masked_scatter would consume the source sequentially and
            # shift every score located after the EOS column.
            token_scores = scores / cur_len**length_penalty  # batch_size x vocab_size
            eos_mask = scores.new_ones(scores.size(1))
            eos_mask[eos_token_id] = 0
            eos_mask = eos_mask.unsqueeze(0).eq(1)
            scores = torch.where(eos_mask, token_scores, scores)

        if do_sample:
            if temperature > 0 and temperature != 1:
                scores = scores / temperature

            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
            probs = F.softmax(scores, dim=-1) + 1e-12

            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
        else:
            next_tokens = torch.argmax(scores, dim=-1)  # batch_size

        if _eos_token_id != -1:
            # Force EOS on the token that lands on the final position.
            next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)
        next_tokens = next_tokens.masked_fill(dones, pad_token_id)
        token_ids = torch.cat([token_ids, next_tokens.unsqueeze(1)], dim=-1)  # batch_size x max_len

        end_mask = next_tokens.eq(_eos_token_id)
        dones = dones | end_mask
        cur_len += 1

        if dones.all():
            break

    # TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
    # be able to return multiple real results
    return token_ids[:, None]


@torch.no_grad()
def _beam_search_generate(
    decoder,
    tokens,
    inference_params=None,
    max_length=20,
    num_beams=4,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    eos_token_id=None,
    do_sample=True,
    repetition_penalty=1.0,
    length_penalty=1.0,
    pad_token_id=0,
    bos_token_id=1,
    feat_mask=None,
    ffn_mask=None,
    layer_mask=None,
) -> torch.LongTensor:
    """
    Decode with beam search (optionally sampling the candidates).

    ``decoder`` must be an ``nn.Module`` (its device is looked up through its parameters).

    Returns:
        LongTensor of shape [bsz, num_return_sequences, longest hypothesis], padded with
        ``pad_token_id``. When ``eos_token_id`` is set, EOS is appended to each hypothesis.
    """
    device = _get_model_device(decoder)
    if device is None:
        # Parameter-less module: fall back to the device of the prompt.
        device = tokens.device
    batch_size = tokens.size(0)

    _eos_token_id = -1 if eos_token_id is None else eos_token_id

    has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
    attention_mask = _build_attention_mask(tokens, bos_token_id, has_bos)

    if inference_params is None:
        inference_params = InferenceParams(
            max_sequence_len=max_length,
            max_batch_size=tokens.size(0),
            sequence_len_offset=0,
            batch_size_offset=0,
            key_value_memory_dict=None,
            lengths_per_sample=None,
            attention_mask=attention_mask,
        )

    # Prompt pass: feed the whole prompt once.
    scores = _last_step_scores(decoder, tokens, inference_params, feat_mask, ffn_mask, layer_mask)
    inference_params.sequence_len_offset += tokens.size(1)
    if _eos_token_id != -1:
        # Never end a beam on the very first generated token.
        scores[:, _eos_token_id] = -1e12
    vocab_size = scores.size(1)
    assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."

    if do_sample:
        probs = F.softmax(scores, dim=-1) + 1e-12
        # (batch_size, num_beams)
        next_tokens = torch.multinomial(probs, num_samples=num_beams)
        logits = probs.log()
        # (batch_size, num_beams)
        next_scores = logits.gather(dim=1, index=next_tokens)
    else:
        scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
        # obtain (batch_size, num_beams), (batch_size, num_beams)
        next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)

    # Replicate every batch row (and its cached state) num_beams times.
    indices = torch.arange(batch_size, dtype=torch.long).to(device)
    indices = indices.repeat_interleave(num_beams)
    inference_params.reorder_state(indices)

    # batch_size * num_beams x length
    tokens = tokens.index_select(dim=0, index=indices)
    # generated token (batch_size', cur_len)
    token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
    dones = [False] * batch_size

    beam_scores = next_scores.view(-1)  # batch_size * num_beams

    cur_len = token_ids.size(1)

    real_max_length = max_length
    max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
    hypos = [
        BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
    ]
    # 0, num_beams, 2*num_beams, ...
    batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

    while cur_len < real_max_length:
        inference_params.attention_mask = _build_attention_mask(token_ids, bos_token_id, has_bos)

        # (bsz x num_beams, vocab_size); only the newest token is fed to the decoder.
        scores = _last_step_scores(
            decoder, token_ids[:, -1:], inference_params, feat_mask, ffn_mask, layer_mask
        )
        inference_params.sequence_len_offset += 1

        if repetition_penalty != 1.0:
            _apply_repetition_penalty(scores, token_ids, repetition_penalty)

        if _eos_token_id != -1:
            # Make EOS overwhelmingly likely for the token that lands on the final position.
            max_len_eos_mask = max_lengths.eq(cur_len + 1)
            eos_scores = scores[:, _eos_token_id]
            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores)

        if do_sample:
            if temperature > 0 and temperature != 1:
                scores = scores / temperature

            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
            probs = F.softmax(scores, dim=-1) + 1e-12

            # batch_size' x (num_beams+1)
            _tokens = torch.multinomial(probs, num_samples=num_beams + 1)

            logits = probs.log()
            # batch_size' x (num_beams+1)
            _scores = logits.gather(dim=1, index=_tokens)
            # batch_size' x (num_beams+1)
            _scores = _scores + beam_scores[:, None]
            _scores = _scores.view(batch_size, num_beams * (num_beams + 1))
            next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
            _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
            # (batch_size, 2*num_beams)
            next_tokens = _tokens.gather(dim=1, index=ids)
            # (batch_size, 2*num_beams)
            from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()
        else:
            # (batch_size * num_beams, vocab_size)
            scores = F.log_softmax(scores, dim=-1)
            # (batch_size * num_beams, vocab_size)
            _scores = scores + beam_scores[:, None]
            # (batch_size, num_beams*vocab_size)
            _scores = _scores.view(batch_size, -1)
            # (bsz, 2*num_beams)
            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
            # (batch_size, 2*num_beams)
            from_which_beam = torch.floor(ids.float() / vocab_size).long()
            next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)

        # Keep, per batch row, the first num_beams candidates that are not EOS.
        not_eos_mask = next_tokens.ne(_eos_token_id)
        keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
        keep_mask = not_eos_mask & keep_mask

        _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
        _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
        _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
        beam_scores = _next_scores.view(-1)

        # Collect the candidates that finish at this step.
        flag = True
        if cur_len + 1 == real_max_length:
            # Final position: the top num_beams candidates of every row are finished.
            eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
            eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
            eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
        else:
            effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id)  # batch_size x num_beams
            if effective_eos_mask.sum().gt(0):
                eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
                eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
                eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
            else:
                flag = False

        if flag:
            for batch_idx, beam_ind, beam_idx in zip(
                eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()
            ):
                if not dones[batch_idx]:
                    score = next_scores[batch_idx, beam_ind].item()
                    # beam_idx is the parent beam, beam_ind the rank of the candidate.
                    prefix = token_ids[batch_idx * num_beams + beam_idx]
                    if _eos_token_id != -1:
                        # EOS is appended when the best hypotheses are selected below.
                        hypos[batch_idx].add(prefix.clone(), score)
                    else:
                        # Without EOS the candidate's own token closes the hypothesis. It must
                        # be looked up by candidate rank; indexing the kept tokens by parent
                        # beam would attach the token of a different candidate.
                        last_token = next_tokens[batch_idx, beam_ind].view(1)
                        hypos[batch_idx].add(torch.cat([prefix, last_token]), score)

        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
        inference_params.reorder_state(reorder_inds)
        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)

        for batch_idx in range(batch_size):
            dones[batch_idx] = (
                dones[batch_idx]
                or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
                or max_lengths[batch_idx * num_beams] == cur_len + 1
            )

        cur_len += 1

        if all(dones):
            break

    # select the best hypotheses
    tgt_len = token_ids.new_zeros(batch_size, num_return_sequences)
    best = []

    for i, hypotheses in enumerate(hypos):
        sorted_hyp = sorted(hypotheses.hyp, key=lambda x: x[0], reverse=True)
        _best = []
        for j, hyp in zip(range(num_return_sequences), sorted_hyp):
            hyp = hyp[1]
            if _eos_token_id != -1:
                hyp = torch.cat([hyp, token_ids.new_ones(1) * _eos_token_id])
            tgt_len[i, j] = len(hyp)
            _best.append(hyp)
        best.append(_best)

    # generate target batch
    decoded = token_ids.new_zeros(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
    for i, hypo in enumerate(best):
        for j, _hypo in enumerate(hypo):
            decoded[i, j, : tgt_len[i, j]] = _hypo

    return decoded


class BeamHypotheses:
    """
    BeamHypotheses: bounded n-best list of finished hypotheses for one batch row.
    """

    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """Initialize n-best list of hypotheses."""
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """Number of hypotheses in the list."""
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.

        The hypothesis is scored with ``sum_logprobs / len(hyp) ** length_penalty`` and is
        stored when the list is not full or when it beats the current worst score; the
        worst entry is evicted once more than ``num_beams`` hypotheses are held.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                # After evicting the lowest entry, the runner-up is the new worst.
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs):
        """If there are enough hypotheses and that none of the hypotheses being generated can become better than
        the worst one in the heap, then we are done with this sentence."""
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / self.max_length**self.length_penalty


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """
    Based on the values of top_k and top_p, set the values that do not meet the criteria to the filter_value.

    Note: ``logits`` is modified in place and also returned.

    Args:
        logits: logit value, shape is [bsz, vocab_size].
        top_k: If it is greater than 0, only the top_k largest logits of each row are kept, and the rest
            of the positions are set to filter_value.
        top_p: nucleus sampling threshold, according to http://arxiv.org/abs/1904.09751. Applied only
            when it is smaller than 1.0.
        filter_value: the value written to the filtered positions.
        min_tokens_to_keep: lower bound applied to top_k; when it is greater than 1, at least this many
            of the highest-scoring tokens of each row are also protected from top_p filtering.
    """
    if top_k > 0:
        # Safety check
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove all tokens with a probability less than the last token of
        # the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        # (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits