# This code is adapted from huggingface transformers:
# https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
from typing import List

import torch
import torch.nn.functional as F

_LOGIT_PROCESSOR_MAP = {}


def register_logit_processor(process_type):
    """
    register a logit processor function for the given process type.
    """

    def register(func):
        global _LOGIT_PROCESSOR_MAP
        _LOGIT_PROCESSOR_MAP[process_type] = func
        return func

    return register


@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
    """
    enforces no repetition of n-grams to avoid repetitions of word sequences.
    """
    if not isinstance(ngram_size, int) or ngram_size < 0:
        raise ValueError(f"'ngram_size={ngram_size}' should be a non-negative integer.")

    if ngram_size != 0:
        batch_size = len(batch_token_ids)

        for batch_id in range(batch_size):
            current_token_ids = batch_token_ids[batch_id]
            current_len = len(current_token_ids)
            if current_len + 1 < ngram_size:
                continue

            # for every (ngram_size - 1)-token prefix seen so far, record the tokens that followed it
            ngrams_dict = {}
            for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]

            # ban every token that would complete an already generated n-gram
            prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len])
            banned_token = ngrams_dict.get(prev_ngrams, [])
            logits[batch_id, banned_token] = -float("inf")

    return logits


@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
    """
    apply the penalty to the tokens present in the prompt.
    """
    if not isinstance(penalty, float) or not (penalty > 0):
        raise ValueError(f"'penalty={penalty}' has to be a strictly positive float.")

    logit_list = []

    # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement
    # presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
    if penalty != 1.0:
        for batch_id in range(len(batch_token_ids)):
            current_logit = logits[batch_id]
            current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)

            # penalize tokens that already appeared: scale negative logits up, positive logits down
            current_score = torch.gather(current_logit, 0, current_token)
            current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
            logit_list.append(current_logit.scatter(0, current_token, current_score))

        logits = torch.stack(logit_list)

    return logits


@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
    """
    apply temperature scaling.
    """
    if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
        except_msg = f"'temperature={temperature}' should be a float in (0.0, 1.0]."
        if temperature == 0.0:
            except_msg += " If you want to use greedy decoding strategies, set `do_sample=False`."
        raise ValueError(except_msg)

    return logits if temperature == 1.0 else logits / temperature


@register_logit_processor("top_k")
def top_k_logit_processor(logits, top_k: int):
    """
    top_k logit processor
    """
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.")

    # mask out every logit below the k-th largest value
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits[indices_to_remove] = -float("inf")
    return logits


@register_logit_processor("top_p")
def top_p_logit_processor(logits, top_p: float):
    """
    top_p logit processor
    """
    if top_p < 0 or top_p > 1.0:
        raise ValueError(f"`top_p` should be a float between 0 and 1, but got {top_p}.")

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # remove tokens once the cumulative probability exceeds top_p, always keeping the most probable token
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1)
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
    logits[indices_to_remove] = -float("inf")
    return logits


def logit_processor(processor: str, logits, *args, **kwargs):
    """
    do logit process for given logits.

    Args:
        processor(str): the type of logit processor
        logits(torch.Tensor): input logits

    Returns:
        logits after process
    """
    if processor not in _LOGIT_PROCESSOR_MAP:
        return logits
    else:
        func = _LOGIT_PROCESSOR_MAP[processor]
        logits = func(logits, *args, **kwargs)
        return logits
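

# Illustrative usage sketch (not part of the upstream file): shows how the registered
# processors are dispatched through `logit_processor` during sampling. The batch contents,
# vocabulary size, and sampling parameters below are arbitrary assumptions for the demo.
if __name__ == "__main__":
    batch_token_ids = [[1, 2, 3, 2], [4, 5, 6, 7]]  # token ids generated so far, per sequence
    logits = torch.randn(len(batch_token_ids), 8)  # [batch_size, vocab_size]

    logits = logit_processor("repetition_penalty", logits, 1.2, batch_token_ids)
    logits = logit_processor("no_repeat_ngram_size", logits, 2, batch_token_ids)
    logits = logit_processor("temperature", logits, 0.7)
    logits = logit_processor("top_k", logits, 5)
    logits = logit_processor("top_p", logits, 0.9)

    # sample the next token for each sequence from the filtered distribution
    probs = F.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)
    print(next_tokens)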