from typing import Any, Callable, List, Optional

import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer

try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper


def _prepare_logits_processor(
    top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
    """
    Prepare the logits processor list based on the given parameters.

    Args:
        top_k (Optional[int]): The number of highest probability logits to keep for each token.
        top_p (Optional[float]): The cumulative probability threshold for selecting tokens.
        temperature (Optional[float]): The temperature value to apply to the logits.

    Returns:
        LogitsProcessorList: The list of logits processors.
    """
    processor_list = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processor_list.append(TopKLogitsWarper(top_k))
    if top_p is not None and top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list


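# A minimal usage sketch for `_prepare_logits_processor` (hypothetical tensors; any
# (batch, vocab_size) logits and (batch, seq_len) input_ids work). It shows how the
# warpers are chained before sampling, mirroring what `_sample` does further below:
#
#     logits_processor = _prepare_logits_processor(top_k=50, top_p=0.9, temperature=0.7)
#     warped_logits = logits_processor(input_ids, next_token_logits)
#     probs = torch.softmax(warped_logits, dim=-1, dtype=torch.float)
#     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

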
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    """
    Check if the sequence generation is finished.

    Args:
        unfinished_sequences (torch.Tensor): Tensor indicating the unfinished sequences.

    Returns:
        bool: True if all sequences are finished, False otherwise.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider DP
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


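# A small worked sketch of the data-parallel check (hypothetical 2-rank run): rank 0 holds
# unfinished_sequences = [0, 1] and rank 1 holds [0, 0]. The all_reduce (default op: SUM)
# makes every rank see [0, 1], so max() == 0 is False and all ranks keep generating until
# every rank reports zeros.

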
def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
    """
    Update the model keyword arguments based on the outputs and new mask.

    Args:
        outputs (dict): The outputs from the model.
        new_mask: The new attention mask.
        **model_kwargs: Additional model keyword arguments.

    Returns:
        dict: The updated model keyword arguments.
    """
    if "past_key_values" in outputs:
        model_kwargs["past_key_values"] = outputs["past_key_values"]
    else:
        model_kwargs["past_key_values"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([attention_mask, new_mask], dim=-1)

    return model_kwargs


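# A minimal sketch of one decoding step's bookkeeping with this helper (hypothetical
# shapes; `outputs` stands for a Hugging Face causal-LM output that is indexable like
# a dict, and `attention_mask` is the (batch, seq_len) mask from the previous step):
#
#     new_mask = torch.ones((batch_size, 1), dtype=torch.long, device=input_ids.device)
#     model_kwargs = update_model_kwargs_fn(outputs, new_mask, attention_mask=attention_mask)
#     # model_kwargs["attention_mask"] now has one extra column, and
#     # model_kwargs["past_key_values"] carries the KV cache for the next step.

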
def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
    model_kwargs["input_ids"] = input_ids
    return model_kwargs


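# The default above simply forwards the full `input_ids`. Because the sampling loops below
# call `prepare_inputs_fn(input_ids, past=past, **model_kwargs)`, a model-specific variant
# can delegate cache-aware slicing to the model itself. A sketch (hypothetical name
# `hf_prepare_inputs_fn`; assumes a Hugging Face causal LM `model` in scope):
#
#     def hf_prepare_inputs_fn(input_ids: torch.Tensor, past=None, **model_kwargs) -> dict:
#         # Let the model trim input_ids / build position ids when a KV cache is present.
#         return model.prepare_inputs_for_generation(input_ids, past_key_values=past, **model_kwargs)

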
def _sample(
    model: Any,
    input_ids: torch.Tensor,
    max_length: int,
    early_stopping: bool = True,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    stop_token_ids: Optional[List[int]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    stream_interval: int = 2,
    **model_kwargs,
) -> torch.Tensor:
    """
    Generates new tokens using the given model and input_ids.

    Args:
        model (Any): The model used for token generation.
        input_ids (torch.Tensor): The input tensor containing the initial tokens.
        max_length (int): The maximum length of the generated tokens.
        early_stopping (bool, optional): Whether to stop generating tokens early if all sequences are finished. Defaults to True.
        eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None.
        pad_token_id (int, optional): The ID of the padding token. Defaults to None.
        stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will stop the generation process. Defaults to None.
        top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None.
        top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None.
        temperature (float, optional): The temperature value for token sampling. Defaults to None.
        max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None.
        prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None.
        update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model kwargs. Defaults to None.
        stream_interval (int, optional): The interval for streaming generation. Defaults to 2.
        **model_kwargs: Additional keyword arguments for the model.

    Returns:
        torch.Tensor: The tensor containing the generated tokens.
    """
    context_length = input_ids.size(1)
    if max_new_tokens is None:
        max_new_tokens = max_length - context_length
    if context_length + max_new_tokens > max_length or max_new_tokens == 0:
        return input_ids

    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    past = None
    for i in range(context_length, context_length + max_new_tokens):
        # Calculate attention mask
        if "attention_mask" not in model_kwargs:
            model_kwargs["attention_mask"] = input_ids.ne(pad_token_id)
        model_inputs = (
            prepare_inputs_fn(input_ids, past=past, **model_kwargs)
            if prepare_inputs_fn is not None
            else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)}
        )
        outputs = model(**model_inputs)

        if "past_key_values" in outputs:
            past = outputs.past_key_values
        elif "mems" in outputs:
            past = outputs.mems

        # NOTE: this is correct only in left padding mode
        next_token_logits = outputs["logits"][:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)

        # Sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # If eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        if stop_token_ids is not None:
            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
            unfinished_sequences = unfinished_sequences.mul(
                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
            )

        # Stop when each sentence is finished if early_stopping=True
        if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
            if i == context_length + max_new_tokens - 1 and stop_token_ids is not None:
                # Force to end with stop token ids
                input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
                    torch.LongTensor(stop_token_ids).to(input_ids.device).long()
                )
            return input_ids


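# `stop_token_ids` is typically derived from the tokenizer of the chat template in use.
# A minimal sketch (hypothetical; assumes a Hugging Face tokenizer and a stop string such
# as "<|im_end|>" that encodes to a fixed token sequence):
#
#     stop_token_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
#     output_ids = _sample(
#         model, input_ids, max_length=512, stop_token_ids=stop_token_ids,
#         top_k=50, top_p=0.9, temperature=0.7,
#         eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
#     )

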
@torch.inference_mode()
def generate(
    model: Any,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor:
    """Generate token sequence. The returned sequence is input_ids + generated_tokens.

    Args:
        model (nn.Module): model
        input_ids (torch.Tensor): input sequence
        tokenizer (PreTrainedTokenizer): tokenizer used for generation; must use left padding
        max_length (int): max length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to sample. Defaults to True.
        early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to True.
        top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
        top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
        temperature (Optional[float], optional): The value used to modulate the next token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
    """
    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
    is_greedy_gen_mode = (num_beams == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and do_sample is False
    if is_greedy_gen_mode:
        raise NotImplementedError
    elif is_sample_gen_mode:
        # Run sample
        res = _sample(
            model,
            input_ids,
            max_length,
            early_stopping=early_stopping,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            prepare_inputs_fn=prepare_inputs_fn,
            update_model_kwargs_fn=update_model_kwargs_fn,
            **model_kwargs,
        )
        return res
    elif is_beam_gen_mode:
        raise NotImplementedError
    else:
        raise ValueError("Unsupported generation mode")


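# A minimal end-to-end sketch for `generate` (hypothetical checkpoint name and prompt; any
# Hugging Face causal LM works as long as the tokenizer uses left padding and has a pad token):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("my-org/my-chat-model", padding_side="left")
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token
#     model = AutoModelForCausalLM.from_pretrained("my-org/my-chat-model").eval()
#
#     batch = tokenizer(["Hello, how are you?"], return_tensors="pt", padding=True)
#     output_ids = generate(model, batch["input_ids"], tokenizer, max_length=128,
#                           do_sample=True, top_k=50, top_p=0.9, temperature=0.7)
#     print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])

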
def _sample_streaming(
    model: Any,
    input_ids: torch.Tensor,
    max_length: int,
    early_stopping: bool = False,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    stop_token_ids: Optional[List[int]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    stream_interval: int = 2,
    **model_kwargs,
) -> torch.Tensor:
    """
    Generates new tokens using a streaming approach.

    Args:
        model (Any): The model used for token generation.
        input_ids (torch.Tensor): The input tensor containing the initial tokens.
        max_length (int): The maximum length of the generated sequence.
        early_stopping (bool, optional): Whether to stop generating tokens for a sequence if it is finished. Defaults to False.
        eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None.
        pad_token_id (int, optional): The ID of the padding token. Defaults to None.
        stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will mark the sequence as finished. Defaults to None.
        top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None.
        top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None.
        temperature (float, optional): The temperature value for sampling. Defaults to None.
        max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None.
        prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None.
        update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model keyword arguments. Defaults to None.
        stream_interval (int, optional): The interval at which to yield the generated tokens. Defaults to 2.
        **model_kwargs: Additional keyword arguments to be passed to the model.

    Yields:
        torch.Tensor: The generated tokens at each step.

    Returns:
        torch.Tensor: The final generated tokens.
    """
    context_length = input_ids.size(1)
    if max_new_tokens is None:
        max_new_tokens = max_length - context_length
    if context_length + max_new_tokens > max_length or max_new_tokens == 0:
        return input_ids

    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    past = None
    for i in range(context_length, context_length + max_new_tokens):
        # calculate attention mask
        if "attention_mask" not in model_kwargs:
            model_kwargs["attention_mask"] = input_ids.ne(pad_token_id)
        model_inputs = (
            prepare_inputs_fn(input_ids, past=past, **model_kwargs)
            if prepare_inputs_fn is not None
            else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)}
        )
        outputs = model(**model_inputs)
        if "past_key_values" in outputs:
            past = outputs.past_key_values
        elif "mems" in outputs:
            past = outputs.mems

        # NOTE: this is correct only in left padding mode
        next_token_logits = outputs["logits"][:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        if stop_token_ids is not None:
            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
            unfinished_sequences = unfinished_sequences.mul(
                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
            )

        # Stop when each sentence is finished if early_stopping=True
        if (
            (early_stopping and _is_sequence_finished(unfinished_sequences))
            or (i - context_length) % stream_interval == 0
            or i == context_length + max_new_tokens - 1
        ):
            yield input_ids
            if early_stopping and _is_sequence_finished(unfinished_sequences):
                break


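# Yield cadence sketch (illustrative only): with stream_interval=2 and context_length=10,
# the loop above yields the running `input_ids` at i = 10, 12, 14, ..., plus a final yield
# on the last step or as soon as every sequence is finished when early_stopping=True.

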
@torch.inference_mode()
def generate_streaming(
    model: Any,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
):
    """Generate a token sequence in a streaming fashion. Each yielded sequence is input_ids + generated_tokens so far.

    Args:
        model (nn.Module): model
        input_ids (torch.Tensor): input sequence
        tokenizer (PreTrainedTokenizer): tokenizer used for generation; must use left padding
        max_length (int): max length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to sample. Defaults to True.
        early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
        top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
        top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
        temperature (Optional[float], optional): The value used to modulate the next token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
    """
    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
    is_greedy_gen_mode = (num_beams == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and do_sample is False
    if is_greedy_gen_mode:
        # run greedy search
        raise NotImplementedError
    elif is_sample_gen_mode:
        # run sample
        for res in _sample_streaming(
            model,
            input_ids,
            max_length,
            early_stopping=early_stopping,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            prepare_inputs_fn=prepare_inputs_fn,
            update_model_kwargs_fn=update_model_kwargs_fn,
            **model_kwargs,
        ):
            yield res
    elif is_beam_gen_mode:
        raise NotImplementedError
    else:
        raise ValueError("Unsupported generation mode")


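# A minimal streaming consumption sketch (hypothetical names; `model` and `tokenizer` are
# set up as in the `generate` example above, with left padding and a pad token):
#
#     batch = tokenizer(["Tell me a story."], return_tensors="pt", padding=True)
#     prompt_len = batch["input_ids"].size(1)
#     for output_ids in generate_streaming(model, batch["input_ids"], tokenizer, max_length=256,
#                                          do_sample=True, top_k=50, top_p=0.9, temperature=0.7):
#         # Decode only the newly generated suffix and overwrite the current console line.
#         text_so_far = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)[0]
#         print(text_so_far, end="\r", flush=True)
#     print()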