from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

from colossalai.utils import get_current_device

from .struct import DrafterOutput, GlideInput


class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
        device (torch.device): The device for the drafter model. Defaults to the current device.
        dtype (torch.dtype): The data type of the drafter model. Defaults to torch.float16.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: PreTrainedTokenizer,
        device: torch.device = None,
        dtype: torch.dtype = torch.float16,
    ):
        self._tokenizer = tokenizer
        self._device = device or get_current_device()
        self._dtype = dtype
        # Move the drafter model to the target device and dtype, then switch to eval mode.
        self._drafter_model = model.to(device=self._device, dtype=self._dtype)
        self._drafter_model.eval()

    def get_model(self) -> nn.Module:
        return self._drafter_model

    @staticmethod
    def trim_kv_cache(
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
    ) -> Tuple[Tuple[torch.FloatTensor]]:
"""Trim the last `invalid_token_num` kv caches.
|
|
|
|
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
|
|
num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
|
|
invalid_token_num (int): The number of invalid tokens to trim.
|
|
"""
        if past_key_values is None or invalid_token_num < 1:
            return past_key_values

        trimmed_past_key_values = []
        for layer_idx in range(len(past_key_values)):
            past_key_value = past_key_values[layer_idx]
            trimmed_past_key_values.append(
                (
                    past_key_value[0][:, :, :-invalid_token_num, :],
                    past_key_value[1][:, :, :-invalid_token_num, :],
                )
            )
        past_key_values = tuple(trimmed_past_key_values)
        return past_key_values

    @torch.inference_mode()
    def speculate(
        self,
        input_ids: torch.Tensor,
        n_spec_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        glide_input: Optional[GlideInput] = None,
    ) -> DrafterOutput:
        """Generate `n_spec_tokens` tokens using the drafter model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            n_spec_tokens (int): Number of tokens to speculate.
            past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): The past key values of the input sequence.
            glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
                used when running the glide model as a drafter.
        """
        assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"

        # For compatibility with transformers versions earlier than 4.38.0
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        logits = []
        token_ids = []

        kwargs = {"return_dict": True, "use_cache": True}
        if glide_input:
            # required only when using the glide model
            kwargs["glide_input"] = glide_input

        for _ in range(n_spec_tokens):
            # update past key values
            kwargs["past_key_values"] = past_key_values

            outputs = self._drafter_model(input_ids, **kwargs)
            next_token_logits = outputs.logits[:, -1, :]

            # NOTE: Only use greedy search for speculating.
            # As the drafter model usually has only a few layers with few parameters,
            # introducing sampling would make the speculation unstable and lead to worse performance.
            next_token_ids = torch.argmax(next_token_logits, dim=-1)

            logits.append(next_token_logits)
            token_ids.append(next_token_ids)
            if next_token_ids.item() == self._tokenizer.eos_token_id:
                # TODO(yuanheng-zhao) support bsz > 1
                break
            input_ids = next_token_ids[:, None]
            past_key_values = outputs.past_key_values

        speculated_length = len(token_ids)  # For now, only support bsz 1
        logits = torch.concat(logits, dim=0)
        token_ids = torch.concat(token_ids, dim=-1)

        out = DrafterOutput(
            speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
        )
        return out
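

# A minimal end-to-end sketch of how `speculate` might be driven (hypothetical names; the verifying
# main model and the number of rejected tokens `n_invalid` are assumptions, not part of this module):
#
#   out = drafter.speculate(input_ids, n_spec_tokens=5, past_key_values=drafter_kv_cache)
#   # Suppose the main model verifies the drafted tokens and rejects the last `n_invalid` of them;
#   # the drafter's kv cache is then rolled back before the next speculation round:
#   drafter_kv_cache = Drafter.trim_kv_cache(out.past_key_values, n_invalid)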