mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
6.4 KiB
138 lines
6.4 KiB
from typing import Any, Callable, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
try:
|
|
from transformers.generation_logits_process import (
|
|
LogitsProcessorList,
|
|
TemperatureLogitsWarper,
|
|
TopKLogitsWarper,
|
|
TopPLogitsWarper,
|
|
)
|
|
except ImportError:
|
|
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
|
|
|
|
|
def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Assemble the logits warpers selected by the sampling arguments.

    Warpers are appended in the fixed order temperature, top-k, top-p. An argument
    contributes nothing when it is ``None`` or holds its neutral value
    (``temperature == 1.0``, ``top_k == 0``, ``top_p >= 1.0``).

    Args:
        top_k (Optional[int], optional): value passed to ``TopKLogitsWarper``. Defaults to None.
        top_p (Optional[float], optional): value passed to ``TopPLogitsWarper``. Defaults to None.
        temperature (Optional[float], optional): value passed to ``TemperatureLogitsWarper``. Defaults to None.

    Returns:
        LogitsProcessorList: the (possibly empty) list of configured warpers.
    """
    # (argument value, "is this setting active?" predicate, warper class), in application order.
    warper_specs = (
        (temperature, lambda value: value != 1.0, TemperatureLogitsWarper),
        (top_k, lambda value: value != 0, TopKLogitsWarper),
        (top_p, lambda value: value < 1.0, TopPLogitsWarper),
    )

    warpers = LogitsProcessorList()
    for setting, is_active, warper_cls in warper_specs:
        # The None check must come first: the predicates are not defined for None.
        if setting is None or not is_active(setting):
            continue
        warpers.append(warper_cls(setting))
    return warpers
|
|
|
|
|
|
def sample(model: nn.Module,
           input_ids: torch.Tensor,
           max_length: int,
           early_stopping: bool = False,
           eos_token_id: Optional[int] = None,
           pad_token_id: Optional[int] = None,
           top_k: Optional[int] = None,
           top_p: Optional[float] = None,
           temperature: Optional[float] = None,
           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
           **model_kwargs) -> torch.Tensor:
    """Extend ``input_ids`` one token at a time by multinomial sampling.

    Each step runs the model, takes the logits at the last position, applies the
    temperature / top-k / top-p warpers and draws one token per row. Decoding stops
    once the sequence length reaches ``max_length``, or earlier when
    ``early_stopping`` is True and every row has produced ``eos_token_id``.

    Args:
        model (nn.Module): called as ``model(**model_inputs)``; its output is indexed with ``['logits']``.
        input_ids (torch.Tensor): prompt token ids; dimension 1 is treated as the sequence dimension
            (assumed shape ``(batch, seq_len)`` -- confirm against callers).
        max_length (int): maximum length of the returned sequence along dimension 1.
        early_stopping (bool, optional): stop as soon as all rows have emitted eos. Defaults to False.
        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
        pad_token_id (Optional[int], optional): token written to rows that already finished.
            Required when ``eos_token_id`` is given. Defaults to None.
        top_k (Optional[int], optional): forwarded to ``prepare_logits_processor``. Defaults to None.
        top_p (Optional[float], optional): forwarded to ``prepare_logits_processor``. Defaults to None.
        temperature (Optional[float], optional): forwarded to ``prepare_logits_processor``. Defaults to None.
        prepare_inputs_fn (Optional[Callable], optional): called as ``prepare_inputs_fn(input_ids, **model_kwargs)``
            to build the model inputs. When None, ``{'input_ids': input_ids}`` is used. Defaults to None.
        update_model_kwargs_fn (Optional[Callable], optional): called as
            ``update_model_kwargs_fn(outputs, **model_kwargs)`` after every step; its return value
            replaces ``model_kwargs``. Defaults to None.
        **model_kwargs: extra keyword arguments handed to the two callbacks above.

    Returns:
        torch.Tensor: ``input_ids`` with the sampled tokens concatenated along the last dimension.
        The input tensor itself is returned when it is already ``max_length`` long or longer.

    Raises:
        ValueError: if ``eos_token_id`` is set, ``pad_token_id`` is None and at least one
            token would have to be generated.
    """
    # Nothing to generate: hand back the prompt untouched (no argument validation on this path).
    if input_ids.size(1) >= max_length:
        return input_ids

    # The check does not depend on loop state, so fail before paying for a forward pass.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # One flag per row: 1 while the row is still generating, 0 once it has emitted eos.
    # Same dtype/device as input_ids so it can be mixed with token ids arithmetically.
    unfinished_sequences = input_ids.new_ones((input_ids.shape[0],))

    for _ in range(input_ids.size(1), max_length):
        if prepare_inputs_fn is not None:
            model_inputs = prepare_inputs_fn(input_ids, **model_kwargs)
        else:
            model_inputs = {'input_ids': input_ids}
        outputs = model(**model_inputs)

        # Only the distribution at the last position is needed to pick the next token.
        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample (softmax is computed in float32 regardless of the logits dtype)
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and unfinished_sequences.max() == 0:
            break

    return input_ids
|
|
|
|
|
|
def generate(model: nn.Module,
             input_ids: torch.Tensor,
             max_length: int,
             num_beams: int = 1,
             do_sample: bool = True,
             early_stopping: bool = False,
             eos_token_id: Optional[int] = None,
             pad_token_id: Optional[int] = None,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             temperature: Optional[float] = None,
             prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
             update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
             **model_kwargs) -> torch.Tensor:
    """Generate token sequence. The returned sequence is input_ids + generated_tokens.

    Only the sampling mode (``num_beams == 1`` and ``do_sample is True``) is implemented;
    it delegates to ``sample``.

    Args:
        model (nn.Module): model
        input_ids (torch.Tensor): input sequence
        max_length (int): max length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to do sample. Defaults to True.
        early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
        pad_token_id (Optional[int], optional): pad token id. Defaults to None.
        top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
        top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
        temperature (Optional[float], optional): The value used to modulate the next token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.

    Returns:
        torch.Tensor: the result of ``sample`` for the given arguments.

    Raises:
        NotImplementedError: for greedy search (``num_beams == 1``, ``do_sample is False``)
            and beam search (``num_beams > 1``, ``do_sample is False``).
        ValueError: for any other combination of ``num_beams`` and ``do_sample``.
    """
    # do_sample is compared by identity on purpose: only the real booleans select a mode.
    is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
    is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
    is_beam_gen_mode = ((num_beams > 1) and do_sample is False)

    if is_greedy_gen_mode:
        raise NotImplementedError("Greedy search (num_beams == 1, do_sample=False) is not implemented.")

    if is_sample_gen_mode:
        return sample(model,
                      input_ids,
                      max_length,
                      early_stopping=early_stopping,
                      eos_token_id=eos_token_id,
                      pad_token_id=pad_token_id,
                      top_k=top_k,
                      top_p=top_p,
                      temperature=temperature,
                      prepare_inputs_fn=prepare_inputs_fn,
                      update_model_kwargs_fn=update_model_kwargs_fn,
                      **model_kwargs)

    if is_beam_gen_mode:
        raise NotImplementedError("Beam search (num_beams > 1, do_sample=False) is not implemented.")

    # Everything else (e.g. num_beams < 1, or sampling with several beams) has no matching mode.
    raise ValueError(f"Unsupported generation mode: num_beams={num_beams!r}, do_sample={do_sample!r}")
|