mirror of https://github.com/hpcaitech/ColossalAI
[Inference/SpecDec] Add Basic Drafter Model Container (#5405)
* [Infer/Fix] Fix dependency in test - RMSNorm kernel (#5399): fix dependency in pytest
* Add drafter model container (basic version)

Branch: feat/speculative-decoding
parent d63c469f45
commit 5a9b05f7b2
@@ -0,0 +1,4 @@
from .drafter import Drafter
from .struct import DrafterOutput

__all__ = ["Drafter", "DrafterOutput"]
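For context, a minimal usage sketch of these exports, assuming this `__init__.py` sits in the `colossalai.inference.spec` package (the test further down imports from `colossalai.inference.spec.drafter`); the toy Llama setup mirrors that test:

import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.spec import Drafter, DrafterOutput

# Toy drafter model and tokenizer, as in the test below
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
drafter_model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).eval()

# Wrap the small (assistant) model; allow at most 5 drafted tokens per step
drafter = Drafter(drafter_model, tokenizer, max_spec_num=5, device=torch.device("cpu"))

prompt_ids = torch.randint(low=5, high=1000, size=(1, 6))
out: DrafterOutput = drafter.speculate(prompt_ids, n=5)
print(out.speculated_length, out.next_tokens)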
@@ -0,0 +1,142 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

from colossalai.utils import get_current_device

from .struct import DrafterOutput


class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
        max_spec_num (int): The maximum number of tokens to speculate.
        device (torch.device): The device for the drafter model.
    """

    def __init__(
        self, model: nn.Module, tokenizer: PreTrainedTokenizer, max_spec_num: int, device: torch.device = None
    ):
        self._drafter_model = model
        self._tokenizer = tokenizer
        self.max_spec_num = max_spec_num
        self.do_sample = False
        self.sample_fn = None
        self._device = device or get_current_device()
        self._past_key_values = None

    @property
    def past_key_values(self) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
        return self._past_key_values

    # Debug usage for now
    @property
    def past_key_values_shape(self):
        if self._past_key_values is None:
            return []
        return self._past_key_values[0][0].shape

    def get_model(self) -> nn.Module:
        return self._drafter_model

    def reset_sample_method(self, sample_fn: callable) -> None:
        self.do_sample = True
        self.sample_fn = sample_fn

    def clear_sample_method(self) -> None:
        self.do_sample = False
        self.sample_fn = None

    def reset_max_spec_num(self, n: int) -> None:
        assert isinstance(n, int) and n > 1
        self.max_spec_num = n

    def reset_past_key_values(self, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None) -> None:
        self._past_key_values = past_key_values

    def trim_kv_cache(self, invalid_token_num: int) -> Tuple[Tuple[torch.FloatTensor]]:
        # Tuple of kv cache tensors: num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
        # Trim the kv cache of the last `invalid_token_num` tokens.
        # The verifier (main model) might reject `invalid_token_num` tokens,
        # so we have to trim those invalid tokens from the drafter model's kv cache.
        assert self._past_key_values is not None
        trimmed_past_key_values = []
        for layer_idx in range(len(self._past_key_values)):
            past_key_value = self._past_key_values[layer_idx]
            trimmed_past_key_values.append(
                (
                    past_key_value[0][:, :, :-invalid_token_num, :],
                    past_key_value[1][:, :, :-invalid_token_num, :],
                )
            )
        self._past_key_values = tuple(trimmed_past_key_values)
        return self._past_key_values

    @torch.inference_mode()
    def speculate(
        self, input_ids: torch.Tensor, n: int, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None
    ) -> DrafterOutput:
        """Generate n tokens using the drafter model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            n (int): Number of tokens to speculate.
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
        """

        assert 0 <= n <= self.max_spec_num, f"Invalid number {n} to speculate"

        # FIXME For compatibility with transformers 4.36.2 (versions before 4.38.0)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        if past_key_values is None:
            past_key_values = self._past_key_values

        logits = []
        token_ids = []

        for _ in range(n):
            outputs = self._drafter_model(
                input_ids,
                return_dict=True,
                use_cache=True,
                past_key_values=past_key_values,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # Skip logits_processor for the drafter model

            # Sample
            if self.do_sample:
                if self.sample_fn is not None:
                    probs = self.sample_fn(next_token_logits)
                else:
                    probs = nn.functional.softmax(next_token_logits, dim=-1)
                next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_token_ids = torch.argmax(next_token_logits, dim=-1)

            logits.append(next_token_logits)
            token_ids.append(next_token_ids)
            if next_token_ids.item() == self._tokenizer.eos_token_id:
                # TODO support bsz > 1
                break
            input_ids = next_token_ids[:, None]
            past_key_values = outputs.past_key_values

        speculated_length = len(token_ids)  # TODO For now, only support bsz 1
        logits = torch.concat(logits, dim=0)
        token_ids = torch.concat(token_ids, dim=-1)
        # Update the cached past_key_values
        self._past_key_values = past_key_values

        out = DrafterOutput(
            speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
        )
        return out
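To illustrate how this container is meant to be driven, here is a hedged sketch of one draft-then-verify step. The `verify` helper and `spec_decode_step` are hypothetical and not part of this commit; `verify` simply counts how many drafted tokens match the main model's greedy predictions:

import torch


def verify(main_model, input_ids: torch.Tensor, drafted_tokens: torch.Tensor) -> int:
    # Hypothetical acceptance check: run the main model once over prompt + draft
    # and count the leading drafted tokens that match its greedy predictions.
    seq = torch.cat([input_ids, drafted_tokens.unsqueeze(0)], dim=-1)
    logits = main_model(seq).logits
    preds = logits[:, input_ids.size(1) - 1 : -1, :].argmax(dim=-1)  # aligned with drafted positions
    matches = (preds[0] == drafted_tokens).long().cumprod(dim=0)
    return int(matches.sum().item())


def spec_decode_step(drafter, main_model, input_ids: torch.Tensor, n: int) -> torch.Tensor:
    out = drafter.speculate(input_ids, n)                  # draft up to n tokens
    accepted = verify(main_model, input_ids, out.next_tokens)
    rejected = out.speculated_length - accepted
    if rejected > 0:
        drafter.trim_kv_cache(rejected)                    # drop kv cache of the rejected draft tokens
    # Append only the accepted tokens to the running sequence
    return torch.cat([input_ids, out.next_tokens[:accepted].unsqueeze(0)], dim=-1)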
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class DrafterOutput:
    """
    Dataclass for drafter model outputs.

    Args:
        speculated_length (int): Speculated length of the output sequence.
            It is always less than or equal to spec_num during the drafter's speculation process.
        logits (torch.FloatTensor): Logits of the output sequence.
        next_tokens (torch.Tensor): Next token ids.
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence.
    """

    speculated_length: int = None
    logits: torch.FloatTensor = None
    next_tokens: torch.Tensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    def __post_init__(self):
        assert self.speculated_length is not None and self.speculated_length >= 0
        if self.past_key_values is not None:
            assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
            assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
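As a quick illustration of the fields and the `__post_init__` checks above, a small sketch with dummy tensors; the shapes follow the batch-size-1 case produced by `Drafter.speculate`, and the vocabulary size is just a placeholder:

import torch

from colossalai.inference.spec import DrafterOutput  # package path inferred from the test below

spec_len, vocab_size = 3, 32000  # placeholder sizes
dummy_logits = torch.randn(spec_len, vocab_size)           # (speculated_length, vocab_size)
dummy_tokens = torch.randint(0, vocab_size, (spec_len,))   # (speculated_length,)

out = DrafterOutput(
    speculated_length=spec_len,
    logits=dummy_logits,
    next_tokens=dummy_tokens,
    past_key_values=None,  # optional; when provided it must be a tuple of per-layer tuples
)

# __post_init__ rejects invalid values, e.g. DrafterOutput(speculated_length=-1)
# raises an AssertionError.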
@@ -0,0 +1,41 @@
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.spec.drafter import Drafter
from colossalai.utils import get_current_device

NUM_LAYERS = 2


@pytest.mark.parametrize("spec_num", [5])
def test_drafter(spec_num: int):
    torch.manual_seed(123)

    device = get_current_device()

    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = toy_config.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    drafter = Drafter(drafter_model, tokenizer, spec_num, device=device)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num)
    past_kv_length = input_ids.size(1) + spec_num - 1

    assert out.speculated_length == spec_num
    assert out.next_tokens.shape == (spec_num,)
    assert out.logits.shape == (spec_num, len(tokenizer))
    assert drafter._past_key_values[0][0].size(2) == out.past_key_values[0][0].size(2) == past_kv_length

    reject_num = 3
    assert reject_num <= spec_num
    drafter.trim_kv_cache(reject_num)
    assert drafter._past_key_values[0][0].size(2) == past_kv_length - reject_num


if __name__ == "__main__":
    test_drafter(spec_num=5)
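The test above only exercises greedy drafting. Below is a hedged sketch of how stochastic drafting could be switched on via `reset_sample_method`; the `temperature_sample` helper and its temperature value are illustrative, not part of this commit:

import torch


def temperature_sample(next_token_logits: torch.Tensor, temperature: float = 0.7) -> torch.Tensor:
    # Must return a probability distribution; Drafter.speculate then draws from it
    # with torch.multinomial once do_sample is enabled.
    return torch.softmax(next_token_logits / temperature, dim=-1)


def speculate_with_sampling(drafter, input_ids: torch.Tensor, n: int):
    drafter.reset_sample_method(temperature_sample)  # sets do_sample = True
    out = drafter.speculate(input_ids, n)
    drafter.clear_sample_method()                    # back to greedy argmax drafting
    return out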