# ColossalAI/colossalai/inference/logit_processors.py
#
# 190 lines
# 6.5 KiB
# Python

# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
import logging
from typing import List, Union
import torch
import torch.nn.functional as F
# Registry mapping a processor-type string to the function that implements it.
_LOGITS_PROCESSOR_MAP = {}


def register_logits_processor(process_type):
    """
    Build a decorator that records a logits processor in the module registry.

    Args:
        process_type: key under which the decorated function is stored in
            ``_LOGITS_PROCESSOR_MAP``; an existing entry with the same key is overwritten.

    Returns:
        A decorator that stores the function it receives and returns that same function unchanged.
    """

    def _decorator(processor_fn):
        # Item assignment on the module-level dict needs no `global` declaration.
        _LOGITS_PROCESSOR_MAP[process_type] = processor_fn
        return processor_fn

    return _decorator
@register_logits_processor("no_repeat_ngram_size")
def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
    """
    Ban every token that would complete an n-gram already present in a sequence.

    Args:
        logits: scores indexed as ``logits[batch_id, token_id]``; modified in place.
        ngram_size(int): length of the n-grams that may occur only once; ``0`` disables the processor.
        batch_token_ids(List[List[int]]): token ids of each sequence, one list per row of ``logits``.

    Returns:
        ``logits`` with the banned entries set to ``-inf``.

    Raises:
        ValueError: if ``ngram_size`` is not a non-negative integer.
    """
    if not isinstance(ngram_size, int) or ngram_size < 0:
        raise ValueError(f"'ngram_size={ngram_size}' should be a non-negative integer.")

    if ngram_size == 0:
        return logits

    prefix_len = ngram_size - 1
    for batch_id, token_ids in enumerate(batch_token_ids):
        seq_len = len(token_ids)
        # Too short to hold an (n-1)-token prefix, so no n-gram can be completed yet.
        if seq_len < prefix_len:
            continue

        # Map every (n-1)-token prefix seen so far to the tokens that followed it.
        followers = {}
        for ngram in zip(*(token_ids[i:] for i in range(ngram_size))):
            followers.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])

        # The last n-1 tokens form the prefix that the next generated token would extend.
        current_prefix = tuple(token_ids[seq_len - prefix_len : seq_len])
        banned_tokens = followers.get(current_prefix)
        if banned_tokens:
            logits[batch_id, banned_tokens] = -float("inf")

    return logits
@register_logits_processor("repetition_penalty")
def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
    """
    Penalize the scores of tokens that already occur in each sequence.

    For every token id listed in ``batch_token_ids[i]``, the score in row ``i`` is divided by
    ``penalty`` when it is non-negative and multiplied by ``penalty`` when it is negative.

    Args:
        logits: scores indexed as ``logits[batch_id]`` -> 1-D row over token ids.
        penalty(float): strictly positive float; ``1.0`` leaves the scores untouched.
        batch_token_ids(List[List[int]]): token ids of each sequence, one list per row of ``logits``.

    Returns:
        The input ``logits`` itself when ``penalty == 1.0`` or ``batch_token_ids`` is empty;
        otherwise a new tensor stacked from one penalized row per entry of ``batch_token_ids``.

    Raises:
        ValueError: if ``penalty`` is not a float greater than 0.
    """
    if not isinstance(penalty, float) or not (penalty > 0):
        raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")

    # A neutral penalty changes nothing, and torch.stack below rejects an empty list of rows.
    if penalty == 1.0 or len(batch_token_ids) == 0:
        return logits

    # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
    penalized_rows = []
    for batch_id, token_ids in enumerate(batch_token_ids):
        row = logits[batch_id]
        token_index = torch.tensor(token_ids, dtype=torch.long, device=logits.device)

        current_score = torch.gather(row, 0, token_index)
        current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
        # Out-of-place scatter: the caller's tensor is left unmodified.
        penalized_rows.append(row.scatter(0, token_index, current_score))

    return torch.stack(penalized_rows)
@register_logits_processor("temperature")
def apply_temperature(logits, temperature: float):
    """
    Scale the logits by dividing them by ``temperature``.

    Args:
        logits: scores to rescale.
        temperature(float): float in the range (0, 1]; ``1.0`` returns ``logits`` unchanged.

    Returns:
        ``logits`` itself when ``temperature == 1.0``, otherwise ``logits / temperature``.

    Raises:
        ValueError: if ``temperature`` is not a float in (0, 1].
    """
    if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
        except_msg = (
            f"'temperature={temperature}' should be a strictly positive float, "
            "less than or equal to 1.0 and greater than 0."
        )
        if temperature == 0.0:
            # Leading space keeps the hint from running into the previous sentence.
            except_msg += " If you want to use greedy decoding strategies, set `do_sample=False`."
        raise ValueError(except_msg)

    if temperature == 1.0:
        return logits
    return logits / temperature
@register_logits_processor("top_k")
def apply_top_k(logits, top_k: int):
    """
    Keep only the ``top_k`` highest scores along the last dimension.

    Args:
        logits: scores to filter; modified in place.
        top_k(int): strictly positive number of candidates to keep. Values at least as large
            as the last dimension keep every candidate.

    Returns:
        ``logits`` with every score below the k-th largest one set to ``-inf``
        (scores tied with the k-th largest are kept).

    Raises:
        ValueError: if ``top_k`` is not a strictly positive integer.
    """
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.")

    # torch.topk raises when k exceeds the last dimension; in that case nothing would be
    # filtered anyway, so return early.
    if top_k >= logits.size(-1):
        return logits

    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    logits[logits < kth_largest] = -float("inf")
    return logits
@register_logits_processor("top_p")
def apply_top_p(logits, top_p: float):
    """
    Nucleus filtering: keep the smallest set of highest-scoring tokens whose cumulative
    softmax probability exceeds ``top_p``; the best token is always kept.

    Args:
        logits: scores to filter along the last dimension; modified in place.
        top_p(float): cumulative probability threshold in the range [0, 1].

    Returns:
        ``logits`` with every token outside the nucleus set to ``-inf``.

    Raises:
        ValueError: if ``top_p`` is outside [0, 1].
    """
    if top_p < 0 or top_p > 1.0:
        raise ValueError(f"`top_p` should be a float in the range [0, 1], but got {top_p}.")

    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Shift the mask right by one so the token that crosses the threshold is itself kept,
    # then force the best token to survive.
    sorted_mask = torch.roll(cumulative_probs > top_p, shifts=1, dims=-1)
    sorted_mask[..., 0] = False

    # Map the mask from sorted order back to the original token positions. Scattering along
    # the last dimension matches the dimension used for sort/softmax/cumsum above.
    removal_mask = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
    logits[removal_mask] = -float("inf")
    return logits
@register_logits_processor("forced_eos_token_id")
def apply_forced_eos_token_id(
    logits: torch.Tensor,
    sequence_lengths: Union[torch.Tensor, List[int]],
    max_lengths: Union[torch.Tensor, List[int]],
    eos_token_id: Union[int, List[int]],
):
    """
    Enforces the specified token as the last generated token when the maximum output length
    is reached. Notice that the maximum output lengths for different sequences, even if they're
    in the same batch, can be different.

    A row is forced when its sequence length equals its maximum length minus one. For such a
    row every score is set to ``-inf`` and the score of each eos token id is then set to ``0``.

    Args:
        logits(torch.Tensor): logits, one row per sequence; modified in place
        sequence_lengths(Union[torch.Tensor, List[int]]): sequence lengths including prompt and output tokens
        max_lengths(Union[torch.Tensor, List[int]]): the maximum length for each sequence
        eos_token_id(Union[int, List[int]]): forced eos token id(s)

    Returns:
        ``logits`` after forcing the eos token(s) on the selected rows.
    """
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if isinstance(sequence_lengths, torch.Tensor):
        sequence_lengths = sequence_lengths.tolist()
    if isinstance(max_lengths, torch.Tensor):
        max_lengths = max_lengths.tolist()

    # Only the first `num_sequences` entries correspond to rows of `logits`.
    num_sequences = logits.shape[0]
    length_pairs = zip(sequence_lengths[:num_sequences], max_lengths[:num_sequences])
    select_indexes = [
        row for row, (sequence_length, max_out_length) in enumerate(length_pairs)
        if sequence_length == max_out_length - 1
    ]

    if select_indexes:
        logits[select_indexes, :] = -float("inf")
        # Assign one eos id at a time: indexing with both lists at once pairs rows and ids
        # element-wise (or fails on a length mismatch) instead of covering every row/id pair.
        for token_id in eos_token_id:
            logits[select_indexes, token_id] = 0

    return logits
def get_logits_processor(processor: str, logits, *args, **kwargs):
    """
    Look up a registered logits processor by name and apply it.

    Args:
        processor(str): the type of logit processor
        logits(torch.Tensor): input logits
        *args, **kwargs: forwarded unchanged to the registered processor after ``logits``

    Returns:
        The processor's result, or the original ``logits`` (after logging a warning)
        when no processor is registered under ``processor``.
    """
    try:
        handler = _LOGITS_PROCESSOR_MAP[processor]
    except KeyError:
        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
        return logits
    return handler(logits, *args, **kwargs)