mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
from typing import List, Optional, Tuple, Union

import torch
from transformers.generation import GenerationConfig

from colossalai.inference.logit_processors import get_logits_processor


def greedy_sample(
    logprobs: torch.Tensor,
) -> torch.Tensor:
    """
    Sample tokens greedily, i.e. pick the highest-probability token for each sequence.
    """
    results = torch.argmax(logprobs, dim=-1)
    return results


def multinomial_sample(
    probs: torch.Tensor,
) -> torch.Tensor:
    """
    Sample one token per sequence at random, weighted by the given probabilities.
    """
    random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
    return random_results

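
# Illustration (not part of the upstream module): a minimal, hedged sketch of how the two
# samplers above behave on a toy [batch_size, vocab_size] tensor. The values below are made up;
# the manual seed is used only so the random draw in this sketch is reproducible.
#
#   >>> logits = torch.tensor([[0.1, 0.2, 3.0, 0.0], [2.5, 0.1, 0.1, 0.1]])
#   >>> greedy_sample(torch.log_softmax(logits, dim=-1))
#   tensor([2, 0])
#   >>> _ = torch.manual_seed(0)
#   >>> multinomial_sample(torch.softmax(logits, dim=-1)).shape
#   torch.Size([2])
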
def beam_search_sample(
    beam_width: int,
    logprobs: torch.Tensor,
    is_prompt: bool = False,
) -> List[Tuple[List[int], List[int]]]:
    """
    Sample tokens with beam search.

    We sample 2 * beam_width candidates to make sure that with high probability we can get
    `beam_width` candidates in addition to the finished sequences for the next iteration.

    Refs (see for details):
        https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
        https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065

    # NOTE: this beam search sample function is still incomplete: the generation phase does not
    # yet add each beam's cumulative log-probability before taking the top-k candidates.
    """
    results = []
    if is_prompt:
        # Prompt phase: every candidate expands the single prompt sequence (parent beam 0).
        parent_ids = [0] * (2 * beam_width)
        _, next_token_ids = torch.topk(logprobs[0], 2 * beam_width)
        next_token_ids = next_token_ids.tolist()
    else:
        # Generation phase: `logprobs` has shape [beam_width, vocab_size]. Flatten it, take the
        # top 2 * beam_width entries, and recover (parent beam, token id) from each flat index.
        # TODO: add each beam's cumulative log-probability to its row before the top-k, e.g.
        # cumulative_logprobs = [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids]
        # logprobs = logprobs + torch.tensor(cumulative_logprobs, device=logprobs.device).unsqueeze(dim=1)
        vocab_size = logprobs.size(-1)
        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
        parent_ids = torch.div(topk_ids, vocab_size, rounding_mode="floor").tolist()
        next_token_ids = (topk_ids % vocab_size).tolist()

    results.append((next_token_ids, parent_ids))
    return results

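
# Illustration (not part of the upstream module): a hedged, worked example of the flat-index
# bookkeeping in the generation phase above, using made-up numbers. With beam_width = 2 and
# vocab_size = 5, flattened index 7 maps to parent beam 7 // 5 = 1 and token id 7 % 5 = 2,
# i.e. the candidate extends beam 1 with token 2 (the same div/mod step as in the HuggingFace
# reference linked in the docstring).
#
#   >>> topk_ids = torch.tensor([7, 3])
#   >>> torch.div(topk_ids, 5, rounding_mode="floor").tolist(), (topk_ids % 5).tolist()
#   ([1, 0], [2, 3])
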
def search_tokens(
    generation_config: Union[GenerationConfig, dict],
    logits: torch.Tensor,
    is_prompt: bool = False,
    batch_token_ids: Optional[List[List[int]]] = None,
):
    """
    Sample the next token for each sequence in the batch from the given logits.
    """
    # NOTE: need to decide the granularity to process logits (sequence or batch)

    # convert GenerationConfig to dict
    # temporary fix for compatibility with the usage of RPCInferenceEngine
    if isinstance(generation_config, GenerationConfig):
        generation_config = generation_config.to_dict()

    # apply penalty- and constraint-based logits processors
    if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0:
        logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
    if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0:
        logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
    if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None:
        sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
        max_out_lengths = [generation_config.get("max_length") for _ in range(len(batch_token_ids))]
        logits = get_logits_processor(
            "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
        )

    # apply sampling-only logits processors
    if generation_config.get("do_sample", False):
        if (temperature := generation_config.get("temperature", 1.0)) != 1.0:
            logits = get_logits_processor("temperature", logits, temperature)
        if (top_k := generation_config.get("top_k", 0)) != 0:
            logits = get_logits_processor("top_k", logits, top_k)
        if (top_p := generation_config.get("top_p", 1.0)) < 1.0:
            logits = get_logits_processor("top_p", logits, top_p)

    # calculate probs
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # sample the next tokens
    if generation_config.get("num_beams", 1) != 1:
        raise NotImplementedError("Beam search is not supported yet.")
    if generation_config.get("do_sample", False):
        sample_tokens = multinomial_sample(probs)
    else:
        sample_tokens = greedy_sample(logprobs)

    return sample_tokens
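
# Usage sketch (not part of the upstream module): a minimal, hedged example of calling
# `search_tokens` with a plain-dict generation config and made-up logits. With do_sample=False
# and no penalties configured, the call reduces to a greedy argmax over the vocab dimension.
if __name__ == "__main__":
    dummy_logits = torch.tensor(
        [
            [0.1, 0.2, 3.0, 0.0],  # sequence 0: token 2 has the largest logit
            [2.5, 0.1, 0.1, 0.1],  # sequence 1: token 0 has the largest logit
        ]
    )
    next_tokens = search_tokens(
        generation_config={"do_sample": False},
        logits=dummy_logits,
        batch_token_ids=[[1, 2], [3]],
    )
    print(next_tokens)  # expected: tensor([2, 0])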