[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)

* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
feat/speculative-decoding
Yuanheng Zhao 2024-03-11 09:51:42 +08:00 committed by Yuanheng
parent 5a9b05f7b2
commit a37f82629d
11 changed files with 484 additions and 133 deletions

View File

@ -42,6 +42,9 @@ class BatchBucket:
self.device = device or get_current_device()
self.dtype = dtype
self._use_spec_dec = False
self._num_tokens_to_verify = None
self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
@ -88,6 +91,28 @@ class BatchBucket:
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
)
@property
def use_spec_dec(self) -> bool:
return self._use_spec_dec
@property
def num_tokens_to_verify(self) -> int:
assert self.use_spec_dec and self._num_tokens_to_verify is not None
return self._num_tokens_to_verify
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
and let the main model verifies tokens in parallel.
"""
self._use_spec_dec = True
self._num_tokens_to_verify = num_tokens_to_verify
def reset_use_spec_dec(self) -> None:
"""Reset the usage of speculative decoding for the batch bucket"""
self._use_spec_dec = False
self._num_tokens_to_verify = None
def _make_compact(self) -> None:
# Clean and Compress the batch based on its sequences dict.
# Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
@ -347,6 +372,19 @@ class BatchBucket:
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1
def revoke_batch_tokens(self, n: int) -> None:
"""Revoke the last n output tokens of the sequences in the batch
Args:
n (int): The number of output tokens to revoke from each sequence.
Context (input) tokens are not counted.
"""
if n >= 1:
for seq_id, seq in self._sequences_dict.items():
assert seq.output_len >= n, "Number of tokens to revoke exceeds the sequence's current output length"
seq.output_token_id = seq.output_token_id[:-n]
self._sequence_lengths -= n
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.
@ -401,6 +439,21 @@ class BatchBucket:
return True
return False
def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
# Used for main model verification in **Decoding Stage**
# `n` is the number of tokens to be verified,
# so the last `n` tokens of each sequence are prepared as the inputs
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return torch.tensor(out_li, dtype=torch.long, device=self.device)
# For compatibility
def get_1D_inputs(self) -> torch.Tensor:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
@ -411,8 +464,6 @@ class BatchBucket:
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
num_tokens = torch.sum(self._sequence_lengths)
out = torch.empty([num_tokens], dtype=torch.long)
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
@ -420,6 +471,10 @@ class BatchBucket:
return torch.tensor(out_li, dtype=torch.long, device=self.device)
else:
# Assume decoding stage
if self.use_spec_dec:
# For Speculative Decoding
# the number of tokens to be verified in parallel plus the correct token in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"

View File

@ -84,6 +84,8 @@ class InferenceConfig:
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
max_n_spec_tokens (int): The maximum number of tokens to speculate, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
@ -118,6 +120,10 @@ class InferenceConfig:
top_p: Optional[float] = None
min_p: Optional[float] = None
# speculative decoding configs
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False
# paged attention configs
block_size: int = 16
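For reference, a sketch of how the new fields would be set when constructing an `InferenceConfig`; the field names come from this diff, while the concrete values are illustrative only.

```python
from colossalai.inference.config import InferenceConfig

# Illustrative values; max_n_spec_tokens caps how many tokens the drafter may speculate per round
inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=1,
    max_input_len=128,
    max_output_len=128,
    block_size=16,
    max_n_spec_tokens=5,
    glimpse_large_kv=False,
)
```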

View File

@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -52,19 +53,26 @@ class InferenceEngine:
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config.to_generation_config(self.model_config)
self.high_precision = inference_config.high_precision
model = model.eval()
model = model.cuda()
model.to(self.dtype)
self._verify_args()
self.generation_config = inference_config.to_generation_config(self.model_config)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)
# The drafter model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
if model_policy is None:
if self.inference_config.pad_input:
@ -174,21 +182,18 @@ class InferenceEngine:
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_config(self) -> None:
"""
Verify the input config
"""
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
), f"Model {self.model.__class__.__name__} is not supported."
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
def _shardformer(
self,
@ -224,6 +229,138 @@ class InferenceEngine:
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previously set drafter and drafter model, if any, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_spec_dec = False
return
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_spec_dec = False
return
def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This retrieves a single batch and launches inference
with multiple rounds of speculation by the drafter model and verification by the main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
batch.set_use_spec_dec(self.n_spec_tokens) # set batch to use-spec-dec mode
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
drafter_out = self.drafter.speculate(input_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = self.model(batch, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_ids = batch.get_1D_inputs_spec_dec(1)
batch.reset_use_spec_dec() # reset batch use-spec-dec mode
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here would re-allocate the same kv cache for the batch
batch = self.request_handler.running_bb # running batch
batch.set_use_spec_dec(self.n_spec_tokens)
# 3. Decoding - Drafter model speculates `n` tokens
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
logits = self.model(batch, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(self.n_spec_tokens - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
input_ids = batch.get_1D_inputs_spec_dec(1)
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
batch.reset_use_spec_dec()
return finished_sequences
def generate(
self,
prompts: List[str] = None,
@ -246,7 +383,6 @@ class InferenceEngine:
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
@ -257,8 +393,13 @@ class InferenceEngine:
if generation_config is not None:
self.generation_config = generation_config
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
@ -428,7 +569,8 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
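The core of the verification step in `steps_spec_dec` is the accept/reject rule: the main model emits `n + 1` tokens, the first `n` of which are compared against the drafter's speculated tokens, and the first mismatch bounds how many drafted tokens are accepted. A standalone sketch of that rule with made-up token ids:

```python
import torch

n_spec_tokens = 5
next_token_ids_spec = torch.tensor([11, 22, 33, 44, 55])  # drafted by the drafter model
next_tokens = torch.tensor([11, 22, 99, 44, 55, 66])      # main model outputs (n + 1 tokens)

diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

print(n_matches)                      # 2 drafted tokens accepted
print(next_tokens[n_matches].item())  # 99, the main model's correction appended to the sequence
```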

View File

@ -134,8 +134,12 @@ class RequestHandler:
if fd_inter_tensor._tensors_initialized:
fd_inter_tensor._reset()
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
max_n_tokens = self.max_batch_size
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
max_batch_size=max_n_tokens,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
@ -230,6 +234,13 @@ class RequestHandler:
return self.running_bb
def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
assert batch.use_spec_dec
if n > 0:
self.cache_manager.allocate_n_tokens_from_block_tables(
batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
)
def add_sequence(self, req: Sequence):
"""
Add the request to waiting list.
@ -282,13 +293,21 @@ class RequestHandler:
return sample_tokens
def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_id
or sequence.output_len >= generation_config.max_output_len
sequence.output_token_id[-1] == generation_config.eos_token_id
or sequence.output_len >= generation_config.max_length
):
sequence.mark_finished()
def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
for seq in batch.seqs_li:
if (
seq.output_token_id[-1] == generation_config.eos_token_id
or seq.output_len >= generation_config.max_length
):
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
@ -309,9 +328,20 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
return sample_tokens
def append_next_tokens(self, sample_tokens: torch.Tensor):
assert sample_tokens.dim() == 1
n_elements = sample_tokens.size(0)
if not self.prefill_bb.is_empty:
assert (
self.prefill_bb.current_batch_size == n_elements
), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
assert (
self.running_bb.current_batch_size == n_elements
), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
self.running_bb.append_batch_tokens(sample_tokens)
def update(self):
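The resized flash-decoding intermediate tensors follow from the worst case during verification: every sequence in the batch can feed its speculated tokens plus the token accepted in the last step through the decoding kernels at once. A worked example of that sizing with assumed values:

```python
max_batch_size = 8        # assumed batch size
max_n_spec_tokens = 5     # from InferenceConfig.max_n_spec_tokens

# Each sequence contributes n drafted tokens plus the last accepted token per verification step
max_n_tokens = max_batch_size * (max_n_spec_tokens + 1)
print(max_n_tokens)  # 48 query tokens may hit the decoding kernels in a single forward pass
```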

View File

@ -349,6 +349,26 @@ class KVCacheManager:
return seqs_to_recycle
def allocate_n_tokens_from_block_tables(
self,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
bsz: int,
n: int,
) -> List[int]:
"""Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage."""
assert block_tables.dim() == 2
assert context_lens.dim() == 1
bsz = block_tables.size(0) if bsz is None else bsz
assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1
seqs_to_recycle = []
for i in range(n):
seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz)
return seqs_to_recycle
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
"""Allocate space asked on a single block in the block table, specified by the provided position id,
and updates the provided block table with the allocated block.
@ -420,9 +440,7 @@ class KVCacheManager:
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert (
block.available_space > 0
), "Tried to allocate some space but found no available space left in chosen block."
assert block.available_space > 0, f"Found no available space left in the chosen block {block}."
space_to_allocate = min(block.available_space, space_asked)
block.allocate(space_to_allocate)
return space_asked - space_to_allocate
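`allocate_n_tokens_from_block_tables` reuses the single-token allocator `n` times with shifted context lengths, so the number of extra logical blocks it ends up touching is just the number of block-boundary crossings. A small illustrative helper (not part of the PR) makes the arithmetic explicit for `block_size=16`:

```python
import math

def extra_blocks_needed(context_len: int, n: int, block_size: int = 16) -> int:
    """Extra logical blocks required to cache `n` more tokens after `context_len` tokens."""
    return math.ceil((context_len + n) / block_size) - math.ceil(context_len / block_size)

print(extra_blocks_needed(30, 6))  # 1: growing from 30 to 36 tokens crosses the 32-token boundary
print(extra_blocks_needed(16, 6))  # 1: a fresh block is needed as soon as token 17 arrives
```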

View File

@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
@ -84,9 +85,9 @@ def llama_model_forward(
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo): It stores the necessary input information for this inference.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
block_tables = inputmetadata.block_tables
@ -101,7 +102,25 @@ def llama_model_forward(
use_cuda_kernel = False
hidden_states = self.embed_tokens(input_tokens_ids)
if use_cuda_kernel:
cu_seqlens = None
# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
if inputmetadata.use_spec_dec:
# For speculative-decoding Prefill and Verifying Stage
if inputmetadata.is_prompts:
# output tensor shape is the same as normal Prefill Stage
o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
else:
# the number of tokens to be verified in parallel plus the correct token in the last step
n_tokens = inputmetadata.num_tokens_to_verify + 1
assert n_tokens == hidden_states.size(0)
o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
@ -113,14 +132,22 @@ def llama_model_forward(
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
)
cos_sin = (cos, sin)
else:
cu_seqlens = None
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
# TODO (yuanheng-zhao): revise the logic here
# if batch.is_prompts:
# output_tensor = torch.zeros(
# (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
# )
# else:
# output_tensor = torch.zeros(
# (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
# )
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
norm_output = torch.empty_like(hidden_states)
tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
residual = None
for layer_id, decoder_layer in enumerate(self.layers):
@ -131,6 +158,8 @@ def llama_model_forward(
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=inputmetadata.is_prompts,
is_verifier=inputmetadata.use_spec_dec,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=inputmetadata.fd_inter_tensor,
@ -144,9 +173,9 @@ def llama_model_forward(
)
if inputmetadata.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
residual = residual[last_token_indexs - 1].contiguous()
seq_len_cumsum = sequence_lengths.cumsum(dim=0)
hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
residual = residual[seq_len_cumsum - 1].contiguous()
norm_output = torch.empty_like(hidden_states)
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
@ -164,6 +193,8 @@ def llama_decoder_layer_forward(
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
@ -202,6 +233,9 @@ def llama_decoder_layer_forward(
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
@ -312,6 +346,8 @@ class NopadLlamaAttention(LlamaAttention):
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
@ -355,7 +391,7 @@ class NopadLlamaAttention(LlamaAttention):
block_size = k_cache.size(-2)
if is_prompts:
if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
@ -405,17 +441,27 @@ class NopadLlamaAttention(LlamaAttention):
high_precision,
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
q_len = tokens_to_verify + 1 if is_verifier else 1
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
@ -428,8 +474,10 @@ class NopadLlamaAttention(LlamaAttention):
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
return attn_output
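During verification, the rotary embedding has to cover the last `num_tokens_to_verify + 1` positions of each sequence rather than just the newest one. A standalone sketch of the index computation used above, for a single sequence with assumed lengths:

```python
import torch

sequence_lengths = torch.tensor([12])  # kv length including the appended drafted tokens (assumed)
n_tokens = 6                           # num_tokens_to_verify + 1 (assumed)

rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
print(rotary_indexes)  # tensor([ 6,  7,  8,  9, 10, 11]) -> positions of the 6 tokens being verified
```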

View File

@ -15,93 +15,75 @@ class Drafter:
Args:
model (nn.Module): The drafter model.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
max_spec_num (int): The maximum number of tokens to speculate.
device (torch.device): The device for the drafter model.
"""
def __init__(
self, model: nn.Module, tokenizer: PreTrainedTokenizer, max_spec_num: int, device: torch.device = None
self,
model: nn.Module,
tokenizer: PreTrainedTokenizer,
device: torch.device = None,
dtype: torch.dtype = torch.float16,
):
self._drafter_model = model
self._tokenizer = tokenizer
self.max_spec_num = max_spec_num
self.do_sample = False
self.sample_fn = None
self._device = device or get_current_device()
self._past_key_values = None
@property
def past_key_values(self) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
return self._past_key_values
# Debug usage for now
@property
def past_key_values_shape(self):
if self._past_key_values is None:
return []
return self._past_key_values[0][0].shape
self._dtype = dtype
self._drafter_model = model.to(self._device)
self._drafter_model = model.to(self._dtype)
self._drafter_model.eval()
def get_model(self) -> nn.Module:
return self._drafter_model
def reset_sample_method(self, sample_fn: callable) -> None:
self.do_sample = True
self.sample_fn = sample_fn
@staticmethod
def trim_kv_cache(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
) -> Tuple[Tuple[torch.FloatTensor]]:
"""Trim the last `invalid_token_num` kv caches.
def clear_sample_method(self) -> None:
self.do_sample = False
self.sample_fn = None
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
invalid_token_num (int): The number of invalid tokens to trim.
"""
if past_key_values is None or invalid_token_num < 1:
return past_key_values
def reset_max_spec_num(self, n: int) -> None:
assert isinstance(n, int) and n > 1
self.max_spec_num = n
def reset_past_key_values(self, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None) -> None:
self._past_key_values = past_key_values
def trim_kv_cache(self, invalid_token_num) -> Tuple[Tuple[torch.FloatTensor]]:
# Tuple of kv cache tensors: num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
# Trim the last `invalid_token_num` kv caches
# The verifier (main model) might reject `invalid_token_num` tokens,
# so we have to trim the invalid tokens from the kv cache of the drafter model.
assert self._past_key_values is not None
trimmed_past_key_values = []
for layer_idx in range(len(self._past_key_values)):
past_key_value = self._past_key_values[layer_idx]
for layer_idx in range(len(past_key_values)):
past_key_value = past_key_values[layer_idx]
trimmed_past_key_values.append(
(
past_key_value[0][:, :, :-invalid_token_num, :],
past_key_value[1][:, :, :-invalid_token_num, :],
)
)
self._past_key_values = tuple(trimmed_past_key_values)
return self._past_key_values
past_key_values = tuple(trimmed_past_key_values)
return past_key_values
@torch.inference_mode()
def speculate(
self, input_ids: torch.Tensor, n: int, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None
self,
input_ids: torch.Tensor,
n_spec_tokens: int,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> DrafterOutput:
"""Generate n tokens using the drafter model.
"""Generate n_spec_tokens tokens using the drafter model.
Args:
input_ids (torch.Tensor): Input token ids.
n (int): Number of tokens to speculate.
n_spec_tokens (int): Number of tokens to speculate.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
"""
assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
assert 0 <= n <= self.max_spec_num, f"Invalid number {n} to speculate"
# FIXME For compatibility with transformers 4.36.2 (versions before 4.38.0)
# For compatibility with transformers of versions before 4.38.0
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
if past_key_values is None:
past_key_values = self._past_key_values
logits = []
token_ids = []
for _ in range(n):
for _ in range(n_spec_tokens):
outputs = self._drafter_model(
input_ids,
return_dict=True,
@ -110,17 +92,10 @@ class Drafter:
)
next_token_logits = outputs.logits[:, -1, :]
# Skip logits_processor for drafter model
# Sample
if self.do_sample:
if self.sample_fn is not None:
probs = self.sample_fn(next_token_logits)
else:
probs = nn.functional.softmax(next_token_logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token_ids = torch.argmax(next_token_logits, dim=-1)
# NOTE Only use greedy search for speculating.
# As the drafter model usually has only a few layers with few parameters,
# introducing sampling will make the speculation unstable and lead to worse performance.
next_token_ids = torch.argmax(next_token_logits, dim=-1)
logits.append(next_token_logits)
token_ids.append(next_token_ids)
@ -133,8 +108,6 @@ class Drafter:
speculated_length = len(token_ids) # TODO For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
# update past_key_values
self._past_key_values = past_key_values
out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values

View File

@ -44,6 +44,7 @@ def _flash_decoding_fwd_kernel(
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size:
return
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v
@ -52,7 +53,8 @@ def _flash_decoding_fwd_kernel(
# and then support calculating multiple kv cache blocks on an instance
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
# get the current (kv) sequence length
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
# cur_token_off is used as a "mask" here for spec-dec during the verification process
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return
@ -150,7 +152,9 @@ def _flash_decoding_fwd_reduce_kernel(
return
cur_head_idx = tl.program_id(1)
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
# cur_token_off is used as a "mask" here for spec-dec during the verification process
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
offsets_dmodel = tl.arange(0, HEAD_DIM)
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
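The `cur_token_off` term effectively gives each of the `q_len` query tokens its own kv length, which is what enforces causality when several drafted tokens are verified in one kernel launch. The equivalent Python arithmetic, with assumed values:

```python
q_len = 4        # tokens verified in parallel (num_tokens_to_verify + 1), assumed
kv_seq_len = 10  # sequence length after the drafted tokens were appended, assumed

for cur_token_idx in range(q_len):
    cur_token_off = (cur_token_idx % q_len) - q_len + 1  # ranges from -(q_len - 1) up to 0
    print(cur_token_idx, kv_seq_len + cur_token_off)     # each query token attends up to its own position
# 0 7
# 1 8
# 2 9
# 3 10
```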

View File

@ -2,10 +2,15 @@ import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
NUM_LAYERS = 2
MAX_LEN = 100
@pytest.mark.parametrize("spec_num", [5])
@ -14,13 +19,13 @@ def test_drafter(spec_num: int):
device = get_current_device()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = toy_config.eos_token_id
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
drafter = Drafter(drafter_model, tokenizer, spec_num, device=device)
drafter = Drafter(drafter_model, tokenizer, device=device)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num)
@ -29,13 +34,75 @@ def test_drafter(spec_num: int):
assert out.speculated_length == spec_num
assert out.next_tokens.shape == (spec_num,)
assert out.logits.shape == (spec_num, len(tokenizer))
assert drafter._past_key_values[0][0].size(2) == out.past_key_values[0][0].size(2) == past_kv_length
assert out.past_key_values[0][0].size(2) == past_kv_length
reject_num = 3
assert reject_num <= spec_num
drafter.trim_kv_cache(reject_num)
assert drafter._past_key_values[0][0].size(2) == past_kv_length - reject_num
reject_num = max(0, spec_num - 1)
trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def check_sd():
torch.manual_seed(123)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Dummy configs for testing
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
large_config = LlamaConfig(
hidden_size=4096,
intermediate_size=11008,
num_attention_heads=32,
num_hidden_layers=8,
num_key_value_heads=32,
max_position_embeddings=2048,
)
large_config.pad_token_id = tokenizer.eos_token_id
main_model = LlamaForCausalLM(large_config)
inference_config = InferenceConfig(
dtype="fp16",
micro_batch_size=1,
max_batch_size=1,
max_input_len=128,
max_output_len=128,
prefill_ratio=1.2,
block_size=16,
)
engine = InferenceEngine(main_model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
max_length=MAX_LEN,
eos_token_id=tokenizer.eos_token_id,
)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.disable_spec_dec()
engine.clear_spec_dec()
assert not engine.use_spec_dec
assert engine.drafter is None and engine.drafter_model is None
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_sd()
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
spawn(run_dist, nprocs=1)
if __name__ == "__main__":
test_drafter(spec_num=5)
test_spec_dec()

View File

@ -19,12 +19,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
assert q_len <= kv_len
causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
for i in range(bsz):
cur_seq_len = kv_lengths[i].item()
assert cur_seq_len <= kv_len
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
padding_mask[:, :, -q_len:, -q_len:] += causal_mask
return padding_mask
@ -56,11 +63,13 @@ def torch_attn_ref(
attn_scores = qk / (head_dim**0.5)
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
# for left-side padding
if attention_mask.size() != (bsz, 1, q_len, kv_len):
raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}")
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
)
attn_scores = attn_scores + attention_mask
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
out = torch.matmul(attn_weights, v)
if out.size() != (bsz, num_heads, q_len, head_dim):

View File

@ -6,8 +6,8 @@ from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
prepare_padding_mask,
torch_attn_ref,
)
@ -91,9 +91,9 @@ def test_flash_decoding(
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
out_torch = torch_attn_ref(
q, k_torch, v_torch, torch_padding_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
@ -138,6 +138,5 @@ def test_flash_decoding(
assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
if __name__ == "__main__":
test_flash_decoding(16, 32, 32, 16, 1, True)