ColossalAI/colossalai/inference/sampler.py

from typing import List, Tuple

import torch


def greedy_sample(
    generation_config,
    logprobs: torch.Tensor,
) -> torch.Tensor:
    """
    Sample tokens greedily by taking the argmax over the vocabulary dimension.
    """
    results = torch.argmax(logprobs, dim=-1)
    return results


def multinomial_sample(
    generation_config,
    probs: torch.Tensor,
) -> torch.Tensor:
    """
    Sample one token per sequence from the probability distribution `probs`.
    """
    random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
    return random_results
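

# A minimal usage sketch (illustrative only, not part of the original API): the two
# helpers above expect the caller to have already turned the model's final-step logits
# into log-probabilities (for greedy) or probabilities (for multinomial). The names
# `outputs` and `generation_config.do_sample` below are assumptions for illustration:
#
#     logits = outputs[:, -1, :]                      # [batch_size, vocab_size]
#     if generation_config.do_sample:
#         probs = torch.softmax(logits, dim=-1)
#         token_ids = multinomial_sample(generation_config, probs)
#     else:
#         logprobs = torch.log_softmax(logits, dim=-1)
#         token_ids = greedy_sample(generation_config, logprobs)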


def beam_search_sample(
    generation_config,
    logprobs: torch.Tensor,
    cumulative_logprobs: torch.Tensor = None,
    is_prompt: bool = False,
) -> List[Tuple[List[int], List[int]]]:
    """
    Sample tokens with beam search.
    We sample 2 * beam_width candidates to make sure that with high probability we can get
    `beam_width` candidates in addition to the finished sequences for the next iteration.
    ref:
        https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    for details. See also HF reference:
        https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065

    NOTE: `cumulative_logprobs` is expected to hold the running score of each live beam
    (one value per beam); if it is not provided, candidates are ranked by the
    current-step log-probabilities only.
    """
    beam_width = generation_config.num_beams
    results = []
    if is_prompt:
        # Prompt phase: a single distribution, so every candidate descends from beam 0.
        parent_ids = [0] * (2 * beam_width)
        _, next_token_ids = torch.topk(logprobs[0], 2 * beam_width)
        next_token_ids = next_token_ids.tolist()
    else:
        # Generation phase: `logprobs` has shape [beam_width, vocab_size].
        vocab_size = logprobs.size(-1)
        if cumulative_logprobs is not None:
            # Candidate score = running beam score + log-prob of the proposed token.
            logprobs = logprobs + cumulative_logprobs.unsqueeze(dim=1)
        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
        parent_ids = (topk_ids // vocab_size).tolist()
        next_token_ids = (topk_ids % vocab_size).tolist()

    results.append((next_token_ids, parent_ids))
    return results
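

if __name__ == "__main__":
    # Illustrative smoke test (an assumption, not part of the production API): a minimal
    # stand-in config exposing `num_beams`, the only attribute beam_search_sample reads.
    from types import SimpleNamespace

    config = SimpleNamespace(num_beams=2)
    vocab_size = 8

    # Prompt phase: one distribution over the vocabulary.
    prompt_logprobs = torch.log_softmax(torch.randn(1, vocab_size), dim=-1)
    print(beam_search_sample(config, prompt_logprobs, is_prompt=True))

    # Generation phase: one distribution per live beam plus the beams' running scores.
    step_logprobs = torch.log_softmax(torch.randn(config.num_beams, vocab_size), dim=-1)
    running_scores = torch.zeros(config.num_beams)
    print(beam_search_sample(config, step_logprobs, cumulative_logprobs=running_scores))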