# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
import logging

from typing import List, Union

import torch
import torch.nn.functional as F

_LOGITS_PROCESSOR_MAP = {}


def register_logits_processor(process_type):
    """
    Register a logits processor function under the given process type.
    """

    def register(func):
        global _LOGITS_PROCESSOR_MAP
        _LOGITS_PROCESSOR_MAP[process_type] = func
        return func

    return register


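# Note (added for clarity, not in the original file): every processor below is
# registered into _LOGITS_PROCESSOR_MAP through this decorator and is later
# looked up by name in get_logits_processor() at the bottom of this module, e.g.
#
#     @register_logits_processor("my_processor")   # hypothetical name
#     def apply_my_processor(logits, ...):
#         return logits

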
@register_logits_processor("no_repeat_ngram_size") |
|
def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]): |
|
""" |
|
enforces no repetition of n-grams to avoid repetitions of word sequences. |
|
""" |
|
|
|
if not isinstance(ngram_size, int) or ngram_size < 0: |
|
raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.") |
|
|
|
if ngram_size != 0: |
|
batch_size = len(batch_token_ids) |
|
|
|
for batch_id in range(batch_size): |
|
current_token_ids = batch_token_ids[batch_id] |
|
current_len = len(current_token_ids) |
|
if current_len + 1 < ngram_size: |
|
continue |
|
|
|
ngrams_dict = {} |
|
|
|
for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]): |
|
prev_ngram_tuple = tuple(ngram[:-1]) |
|
ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]] |
|
|
|
prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len]) |
|
banned_token = ngrams_dict.get(prev_ngrams, []) |
|
|
|
logits[batch_id, banned_token] = -float("inf") |
|
|
|
return logits |
|
|
|
|
|
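# Illustrative sketch for apply_no_repeat_ngram_size above (added comment, not
# in the original file): with ngram_size=2 and a generated sequence [5, 3, 5],
# the 2-grams seen so far are (5, 3) and (3, 5), so ngrams_dict becomes
# {(5,): [3], (3,): [5]}. The last token is 5, hence prev_ngrams == (5,) and
# token 3 is banned for this step, preventing the bigram "5 3" from being
# generated a second time:
#
#     logits = torch.zeros(1, 8)
#     apply_no_repeat_ngram_size(logits, 2, [[5, 3, 5]])
#     # logits[0, 3] is now -inf; all other entries are unchanged

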
@register_logits_processor("repetition_penalty") |
|
def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]): |
|
""" |
|
apply the penalty to the tokens present in the prompt. |
|
""" |
|
|
|
if not isinstance(penalty, float) or not (penalty > 0): |
|
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.") |
|
|
|
logits_list = [] |
|
|
|
# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. |
|
if penalty != 1.0: |
|
for batch_id in range(len(batch_token_ids)): |
|
current_logit = logits[batch_id] |
|
current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device) |
|
|
|
curretn_socre = torch.gather(current_logit, 0, current_token) |
|
curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty) |
|
logits_list.append(current_logit.scatter(0, current_token, curretn_socre)) |
|
|
|
logits = torch.stack(logits_list) |
|
|
|
return logits |
|
|
|
|
|
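# Illustrative sketch for apply_repetition_penalty above (added comment, not in
# the original file): with penalty=2.0, a previously generated token whose
# logit is 4.0 is scaled down to 2.0 (divided by the penalty), while one whose
# logit is -2.0 is pushed further down to -4.0 (multiplied by the penalty), so
# every token that already appeared becomes less likely to be sampled again:
#
#     apply_repetition_penalty(torch.tensor([[4.0, -2.0, 0.5]]), 2.0, [[0, 1]])
#     # -> tensor([[ 2.0, -4.0, 0.5]])

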
@register_logits_processor("temperature") |
|
def apply_temperature(logits, temperature: float): |
|
""" |
|
apply temperature scaling. |
|
""" |
|
|
|
if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0): |
|
except_msg = f"'temperature={temperature}' should be a strictly positive float, less than or equal to 1.0 and greater than 0." |
|
if temperature == 0.0: |
|
except_msg += "if you want to use greedy decoding strategies, set `do_sample=False`." |
|
raise ValueError(except_msg) |
|
|
|
return logits if temperature == 1.0 else logits / temperature |
|
|
|
|
|
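# Illustrative sketch for apply_temperature above (added comment, not in the
# original file): dividing the logits by a temperature below 1.0 sharpens the
# softmax distribution. For logits [2.0, 1.0], softmax gives roughly
# [0.73, 0.27]; with temperature=0.5 the scaled logits become [4.0, 2.0] and
# softmax gives roughly [0.88, 0.12], so sampling concentrates more on the
# highest-scoring token.

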
@register_logits_processor("top_k") |
|
def apply_top_k(logits, top_k: int): |
|
""" |
|
top_k logit processor |
|
""" |
|
|
|
if not isinstance(top_k, int) or top_k <= 0: |
|
raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.") |
|
|
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] |
|
logits[indices_to_remove] = -float("inf") |
|
return logits |
|
|
|
|
|
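# Illustrative sketch for apply_top_k above (added comment, not in the original
# file): torch.topk returns the k largest logits per row, and the kth-largest
# value is used as the cutoff, so everything strictly below it is masked:
#
#     apply_top_k(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), 2)
#     # -> tensor([[-inf, 3.0, 2.0, -inf]])

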
@register_logits_processor("top_p") |
|
def apply_top_p(logits, top_p: float): |
|
""" |
|
top_p logit processor |
|
""" |
|
|
|
if top_p < 0 or top_p > 1.0: |
|
raise ValueError(f"`top_p` should be a float > 0 and < 1, but got {top_p}.") |
|
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
|
|
|
sorted_indices_to_remove = cumulative_probs > top_p |
|
|
|
sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1) |
|
sorted_indices_to_remove[..., 0] = 0 |
|
|
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) |
|
logits[indices_to_remove] = -float("inf") |
|
return logits |
|
|
|
|
|
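# Illustrative sketch for apply_top_p above (added comment, not in the original
# file): suppose the softmax probabilities of a row, sorted in descending
# order, are [0.5, 0.3, 0.15, 0.05] and top_p=0.7. The cumulative sums are
# [0.5, 0.8, 0.95, 1.0], so the raw mask is [False, True, True, True]; after
# the roll and resetting position 0 it becomes [False, False, True, True],
# keeping the two most likely tokens (the one that crosses the 0.7 threshold
# is still kept) and masking the rest to -inf.

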
@register_logits_processor("forced_eos_token_id") |
|
def apply_forced_eos_token_id( |
|
logits: torch.Tensor, |
|
sequence_lengths: Union[torch.Tensor, List[int]], |
|
max_lengths: Union[torch.Tensor, List[int]], |
|
eos_token_id: Union[int, List[int]], |
|
): |
|
""" |
|
Enforces the specified token as the last generated token when the maximum output length |
|
is reached. Notice that the maximum output lengths for different sequences, even if they're |
|
in the same batch, can be different. |
|
|
|
Args: |
|
logits(torch.Tensor): logits |
|
sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens |
|
max_lengths(torch.Tensor): the maximum length for each sequence |
|
eos_token_id(Union[int, List[int]]): forced eos token id |
|
""" |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
if isinstance(sequence_lengths, torch.Tensor): |
|
sequence_lengths = sequence_lengths.tolist() |
|
if isinstance(max_lengths, torch.Tensor): |
|
max_lengths = max_lengths.tolist() |
|
|
|
select_indexes = [] |
|
num_sequences = logits.shape[0] |
|
sequence_lengths = sequence_lengths[:num_sequences] |
|
max_lengths = max_lengths[:num_sequences] |
|
for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)): |
|
if sequence_length == max_out_length - 1: |
|
select_indexes.append(i) |
|
if select_indexes: |
|
logits[select_indexes, :] = -float("inf") |
|
logits[select_indexes, eos_token_id] = 0 |
|
|
|
return logits |
|
|
|
|
|
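# Illustrative sketch for apply_forced_eos_token_id above (added comment, not
# in the original file): if a sequence currently holds max_length - 1 tokens,
# its entire logits row is masked to -inf and the eos position is reset to 0,
# so eos is the only token with a finite logit and the next (final) generated
# token is guaranteed to be eos.

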
def get_logits_processor(processor: str, logits, *args, **kwargs):
    """
    Apply the specified logits processor to the given logits.

    Args:
        processor(str): the type of logits processor
        logits(torch.Tensor): input logits

    Returns:
        logits after processing
    """
    if processor not in _LOGITS_PROCESSOR_MAP:
        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
    else:
        func = _LOGITS_PROCESSOR_MAP[processor]
        logits = func(logits, *args, **kwargs)

    return logits
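

if __name__ == "__main__":
    # Illustrative usage sketch (added, not part of the original file): chain a
    # few of the registered processors on dummy data, roughly the way an
    # inference engine might call them between the model forward pass and
    # sampling. The shapes and hyperparameter values below are assumptions.
    dummy_logits = torch.randn(2, 32000)  # (batch_size, vocab_size)
    dummy_token_ids = [[1, 5, 3, 5], [1, 7, 7, 2]]  # previously generated ids per sequence

    dummy_logits = get_logits_processor("repetition_penalty", dummy_logits, 1.2, dummy_token_ids)
    dummy_logits = get_logits_processor("no_repeat_ngram_size", dummy_logits, 2, dummy_token_ids)
    dummy_logits = get_logits_processor("temperature", dummy_logits, 0.7)
    dummy_logits = get_logits_processor("top_k", dummy_logits, 50)
    dummy_logits = get_logits_processor("top_p", dummy_logits, 0.9)

    probs = F.softmax(dummy_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)
    print(f"sampled next tokens: {next_tokens.squeeze(-1).tolist()}")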