ColossalAI/colossalai/inference/spec/drafter.py

122 lines
4.5 KiB
Python

from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from colossalai.utils import get_current_device
from .struct import DrafterOutput, GlideInput
class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    On construction the wrapped model is moved to the target device, cast to the
    target dtype and switched to eval mode.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
            This class only reads its ``eos_token_id`` (to stop speculating early).
        device (torch.device): The device for the drafter model. When not given,
            ``get_current_device()`` is used.
        dtype (torch.dtype): The dtype the drafter model is cast to.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: "PreTrainedTokenizer",
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float16,
    ):
        self._tokenizer = tokenizer
        self._device = device or get_current_device()
        self._dtype = dtype
        # Move and cast in one call rather than two separate passes over the model.
        self._drafter_model = model.to(device=self._device, dtype=self._dtype)
        self._drafter_model.eval()

    def get_model(self) -> nn.Module:
        """Return the wrapped drafter model."""
        return self._drafter_model

    @staticmethod
    def trim_kv_cache(
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
    ) -> Tuple[Tuple[torch.FloatTensor]]:
        """Trim the last `invalid_token_num` kv caches.

        Args:
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
                num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
            invalid_token_num (int): The number of invalid tokens to trim.

        Returns:
            The input object itself when it is None or when `invalid_token_num` < 1;
            otherwise a new tuple of (key, value) pairs, one per layer, whose
            seq_len dimension (dim 2) has lost its last `invalid_token_num` entries.
        """
        if past_key_values is None or invalid_token_num < 1:
            return past_key_values

        # Only elements 0 (key) and 1 (value) of each layer entry are kept.
        return tuple(
            (
                layer_kv[0][:, :, :-invalid_token_num, :],
                layer_kv[1][:, :, :-invalid_token_num, :],
            )
            for layer_kv in past_key_values
        )

    @torch.inference_mode()
    def speculate(
        self,
        input_ids: torch.Tensor,
        n_spec_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        glide_input: Optional["GlideInput"] = None,
    ) -> "DrafterOutput":
        """Generate up to n_spec_tokens tokens using the drafter model.

        Decoding is greedy and stops early once the EOS token is produced
        (the EOS token is included in the output).

        Args:
            input_ids (torch.Tensor): Input token ids. A 1D tensor is treated as a single sequence.
            n_spec_tokens (int): Number of tokens to speculate. Must be >= 1.
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
            glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
                when using the glide model as a drafter.

        Returns:
            DrafterOutput: holds
                speculated_length: number of tokens actually speculated,
                logits: last-position logits of every step, concatenated along dim 0,
                next_tokens: the speculated token ids, concatenated along the last dim,
                past_key_values: the cache returned by the last drafter forward pass
                    of this call, including the step that produced EOS.
        """
        assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"

        # For compatibility with transformers of versions before 4.38.0
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        logits = []
        token_ids = []

        kwargs = {"return_dict": True, "use_cache": True}
        if glide_input:
            # required only when using glide model
            kwargs["glide_input"] = glide_input

        for _ in range(n_spec_tokens):
            # update past key values
            kwargs["past_key_values"] = past_key_values

            outputs = self._drafter_model(input_ids, **kwargs)
            next_token_logits = outputs.logits[:, -1, :]

            # NOTE Only use greedy search for speculating.
            #      As the drafter model usually has only a few layers with few parameters,
            #      introducing sampling will make the speculation unstable and lead to worse performance.
            next_token_ids = torch.argmax(next_token_logits, dim=-1)

            logits.append(next_token_logits)
            token_ids.append(next_token_ids)

            # Capture the cache BEFORE the EOS check: otherwise an early stop would return
            # a cache that is one step behind the one returned when no EOS is hit
            # (or None if EOS is produced on the very first step with no incoming cache).
            past_key_values = outputs.past_key_values

            if next_token_ids.item() == self._tokenizer.eos_token_id:
                # TODO(yuanheng-zhao) support bsz > 1
                break
            input_ids = next_token_ids[:, None]

        speculated_length = len(token_ids)  # For now, only support bsz 1
        logits = torch.concat(logits, dim=0)
        token_ids = torch.concat(token_ids, dim=-1)

        out = DrafterOutput(
            speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
        )
        return out