[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)

* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
feat/speculative-decoding
Yuanheng Zhao 2024-03-11 09:51:42 +08:00 committed by Yuanheng
parent 5a9b05f7b2
commit a37f82629d
11 changed files with 484 additions and 133 deletions


@@ -42,6 +42,9 @@ class BatchBucket:
self.device = device or get_current_device()
self.dtype = dtype
self._use_spec_dec = False
self._num_tokens_to_verify = None
self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
@@ -88,6 +91,28 @@ class BatchBucket:
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
)
@property
def use_spec_dec(self) -> bool:
return self._use_spec_dec
@property
def num_tokens_to_verify(self) -> int:
assert self.use_spec_dec and self._num_tokens_to_verify is not None
return self._num_tokens_to_verify
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
and let the main model verifies tokens in parallel.
"""
self._use_spec_dec = True
self._num_tokens_to_verify = num_tokens_to_verify
def reset_use_spec_dec(self) -> None:
"""Reset the usage of speculative decoding for the batch bucket"""
self._use_spec_dec = False
self._num_tokens_to_verify = None
def _make_compact(self) -> None:
# Clean and Compress the batch based on its sequences dict.
# Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
@@ -347,6 +372,19 @@ class BatchBucket:
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1
def revoke_batch_tokens(self, n: int) -> None:
"""Revoke the last n output tokens of the sequences in the batch
Args:
n (int): The number of output tokens to revoke from each sequence.
This does not include the context (input) tokens.
"""
if n >= 1:
for seq_id, seq in self._sequences_dict.items():
assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence"
seq.output_token_id = seq.output_token_id[:-n]
self._sequence_lengths -= n
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.
@@ -401,6 +439,21 @@ class BatchBucket:
return True
return False
def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
# Used for main model verification in **Decoding Stage**
# `n` is the number of tokens to be verified,
# so the last `n` tokens of each sequence are prepared as the inputs
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return torch.tensor(out_li, dtype=torch.long, device=self.device)
# For compatibility
def get_1D_inputs(self) -> torch.Tensor:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
@@ -411,8 +464,6 @@ class BatchBucket:
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
@@ -420,6 +471,10 @@ class BatchBucket:
return torch.tensor(out_li, dtype=torch.long, device=self.device)
else:
# Assume decoding stage
if self.use_spec_dec:
# For Speculative Decoding
# the number of tokens to be verified in parallel plus the correct token in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"


@@ -84,6 +84,8 @@ class InferenceConfig:
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
max_n_spec_tokens (int): The maximum number of tokens to speculate, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
@@ -118,6 +120,10 @@ class InferenceConfig:
top_p: Optional[float] = None
min_p: Optional[float] = None
# speculative decoding configs
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False
# paged attention configs
block_size: int = 16
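A minimal sketch of constructing the config with the two new fields (values are illustrative only; the other arguments mirror the test at the bottom of this diff):

```python
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    max_batch_size=1,
    max_input_len=128,
    max_output_len=128,
    block_size=16,
    max_n_spec_tokens=5,     # upper bound on tokens drafted per speculation round
    glimpse_large_kv=False,  # whether to use large KV in the drafter model
)
```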


@@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -52,19 +53,26 @@ class InferenceEngine:
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.high_precision = inference_config.high_precision
self._verify_args()
self.generation_config = inference_config.to_generation_config(self.model_config)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)
# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
if model_policy is None:
if self.inference_config.pad_input:
@@ -174,21 +182,18 @@
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
def _shardformer(
self,
@@ -224,6 +229,138 @@
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if they exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_spec_dec = False
return
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_spec_dec = False
return
def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This retrieves a single batch and runs inference
with multiple rounds of speculation by the drafter model and verification by the main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
batch.set_use_spec_dec(self.n_spec_tokens) # set batch to use-spec-dec mode
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
drafter_out = self.drafter.speculate(input_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = self.model(batch, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_ids = batch.get_1D_inputs_spec_dec(1)
batch.reset_use_spec_dec() # reset batch use-spec-dec mode
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here would re-allocate the same kv cache for the batch
batch = self.request_handler.running_bb # running batch
batch.set_use_spec_dec(self.n_spec_tokens)
# 3. Decoding - Drafter model speculates `n` tokens
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
logits = self.model(batch, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(self.n_spec_tokens - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
input_ids = batch.get_1D_inputs_spec_dec(1)
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
batch.reset_use_spec_dec()
return finished_sequences
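The accept/reject bookkeeping in step 5 can be illustrated with toy tensors; this is a self-contained sketch that mirrors the `diff_indexes`/`n_matches` logic above, not engine code:

```python
import torch

n_spec_tokens = 5
drafted = torch.tensor([7, 3, 9, 2, 4])      # tokens proposed by the drafter
verified = torch.tensor([7, 3, 9, 8, 1, 6])  # main-model outputs for the n + 1 verifier inputs

# The first mismatch between drafted tokens and the verifier's outputs bounds
# how many drafted tokens survive this round.
diff_indexes = torch.nonzero(~(verified[:-1] == drafted))
n_matches = n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
print(n_matches)  # -> 3

# The surviving drafted tokens are kept and the verifier's own token at the
# mismatch position is appended; the rest are revoked from the batch, and the
# drafter's kv cache is trimmed by (n_spec_tokens - n_matches - 1) as above.
accepted = drafted[:n_matches].tolist() + [verified[n_matches].item()]
print(accepted)  # -> [7, 3, 9, 8]
```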
def generate(
self,
prompts: List[str] = None,
@@ -246,7 +383,6 @@
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
@@ -257,6 +393,11 @@
if generation_config is not None:
self.generation_config = generation_config
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
@@ -428,7 +569,8 @@
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()


@@ -134,8 +134,12 @@ class RequestHandler:
if fd_inter_tensor._tensors_initialized:
fd_inter_tensor._reset()
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
max_n_tokens = self.max_batch_size
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
fd_inter_tensor.initialize(
max_batch_size=max_n_tokens,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
@@ -230,6 +234,13 @@
return self.running_bb
def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
assert batch.use_spec_dec
if n > 0:
self.cache_manager.allocate_n_tokens_from_block_tables(
batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
)
def add_sequence(self, req: Sequence):
"""
Add the request to waiting list.
@@ -282,13 +293,21 @@
return sample_tokens
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_token_id
or sequence.output_len >= generation_config.max_length
):
sequence.mark_finished()
def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
for seq in batch.seqs_li:
if (
seq.output_token_id[-1] == generation_config.eos_token_id
or seq.output_len >= generation_config.max_length
):
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
@@ -309,9 +328,20 @@
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
return sample_tokens
def append_next_tokens(self, sample_tokens: torch.Tensor):
assert sample_tokens.dim() == 1
n_elements = sample_tokens.size(0)
if not self.prefill_bb.is_empty:
assert (
self.prefill_bb.current_batch_size == n_elements
), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
assert (
self.running_bb.current_batch_size == n_elements
), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
self.running_bb.append_batch_tokens(sample_tokens)
def update(self):


@@ -349,6 +349,26 @@ class KVCacheManager:
return seqs_to_recycle
def allocate_n_tokens_from_block_tables(
self,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
bsz: int,
n: int,
) -> List[int]:
"""Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage."""
assert block_tables.dim() == 2
assert context_lens.dim() == 1
bsz = block_tables.size(0) if bsz is None else bsz
assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1
seqs_to_recycle = []
for i in range(n):
seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz)
return seqs_to_recycle
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
"""Allocate space asked on a single block in the block table, specified by the provided position id,
and updates the provided block table with the allocated block.
@@ -420,9 +440,7 @@
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert block.available_space > 0, f"Found no available space left in the chosen block {block}."
space_to_allocate = min(block.available_space, space_asked)
block.allocate(space_to_allocate)
return space_asked - space_to_allocate


@@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
@@ -84,9 +85,9 @@ def llama_model_forward(
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
block_tables = inputmetadata.block_tables
@@ -101,7 +102,25 @@
use_cuda_kernel = False
hidden_states = self.embed_tokens(input_tokens_ids)
cu_seqlens = None
# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
if inputmetadata.use_spec_dec:
# For speculative-decoding Prefill and Verifying Stage
if inputmetadata.is_prompts:
# output tensor shape is the same as normal Prefill Stage
o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
else:
# the number of tokens to be verified in parallel plus the correct token in the last step
n_tokens = inputmetadata.num_tokens_to_verify + 1
assert n_tokens == hidden_states.size(0)
o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
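The rotary position indexes built above for the verification pass can be checked with a toy example (bsz = 1, matching the current spec-dec restriction; this is not part of the modeling code):

```python
import torch

sequence_lengths = torch.tensor([12])  # one sequence whose kv length is 12
n_tokens = 4                           # num_tokens_to_verify + 1

# Same comprehension as in the diff: positions of the last n_tokens inputs.
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
print(rotary_indexes.tolist())  # -> [8, 9, 10, 11]
```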
@@ -113,14 +132,22 @@
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
)
cos_sin = (cos, sin)
else:
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
# TODO (yuanheng-zhao): revise the logic here
# if batch.is_prompts:
# output_tensor = torch.zeros(
# (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
# )
# else:
# output_tensor = torch.zeros(
# (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
# )
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
norm_output = torch.empty_like(hidden_states)
tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
residual = None
for layer_id, decoder_layer in enumerate(self.layers):
@@ -131,6 +158,8 @@
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=inputmetadata.is_prompts,
is_verifier=inputmetadata.use_spec_dec,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=inputmetadata.fd_inter_tensor,
@@ -144,9 +173,9 @@
)
if inputmetadata.is_prompts:
seq_len_cumsum = sequence_lengths.cumsum(dim=0)
hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
residual = residual[seq_len_cumsum - 1].contiguous()
norm_output = torch.empty_like(hidden_states)
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
@@ -164,6 +193,8 @@ def llama_decoder_layer_forward(
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
@@ -202,6 +233,9 @@
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
@@ -312,6 +346,8 @@ class NopadLlamaAttention(LlamaAttention):
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
@@ -355,7 +391,7 @@
block_size = k_cache.size(-2)
if is_prompts:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
@@ -404,6 +440,16 @@
block_tables,
high_precision,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
else:
decoding_fused_rotary_embedding(
query_states,
@@ -428,8 +474,10 @@
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
return attn_output


@@ -15,93 +15,75 @@ class Drafter:
Args:
model (nn.Module): The drafter model.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
device (torch.device): The device for the drafter model.
"""
def __init__(
self,
model: nn.Module,
tokenizer: PreTrainedTokenizer,
device: torch.device = None,
dtype: torch.dtype = torch.float16,
):
self._tokenizer = tokenizer
self._device = device or get_current_device()
self._dtype = dtype
self._drafter_model = model.to(self._device)
self._drafter_model = model.to(self._dtype)
self._drafter_model.eval()
def get_model(self) -> nn.Module:
return self._drafter_model
@staticmethod
def trim_kv_cache(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
) -> Tuple[Tuple[torch.FloatTensor]]:
"""Trim the last `invalid_token_num` kv caches.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
invalid_token_num (int): The number of invalid tokens to trim.
"""
if past_key_values is None or invalid_token_num < 1:
return past_key_values
trimmed_past_key_values = []
for layer_idx in range(len(past_key_values)):
past_key_value = past_key_values[layer_idx]
trimmed_past_key_values.append(
(
past_key_value[0][:, :, :-invalid_token_num, :],
past_key_value[1][:, :, :-invalid_token_num, :],
)
)
past_key_values = tuple(trimmed_past_key_values)
return past_key_values
@torch.inference_mode()
def speculate(
self,
input_ids: torch.Tensor,
n_spec_tokens: int,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> DrafterOutput:
"""Generate n_spec_tokens tokens using the drafter model.
Args:
input_ids (torch.Tensor): Input token ids.
n_spec_tokens (int): Number of tokens to speculate.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
"""
assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
# For compatibility with transformers of versions before 4.38.0
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
logits = []
token_ids = []
for _ in range(n_spec_tokens):
outputs = self._drafter_model(
input_ids,
return_dict=True,
@@ -110,16 +92,9 @@ class Drafter:
)
next_token_logits = outputs.logits[:, -1, :]
# NOTE Only use greedy search for speculating.
# As the drafter model usually has only a few layers with few parameters,
# introducing sampling will make the speculation unstable and lead to worse performance.
next_token_ids = torch.argmax(next_token_logits, dim=-1)
logits.append(next_token_logits)
@@ -133,8 +108,6 @@ class Drafter:
speculated_length = len(token_ids) # TODO For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
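A hedged usage sketch of the refactored, stateless Drafter shown above: the kv cache is now passed in and returned instead of being stored on the instance, so the caller decides how much of it to keep between rounds. The model, tokenizer, and trim amount below are illustrative placeholders:

```python
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.spec.drafter import Drafter

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
drafter_model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).eval().cuda()
drafter = Drafter(drafter_model, tokenizer, device=torch.device("cuda"))

input_ids = torch.randint(low=5, high=1000, size=(1, 6), device="cuda")
out = drafter.speculate(input_ids, n_spec_tokens=5, past_key_values=None)

# Suppose the verifier rejected the last two drafted tokens: drop their kv
# entries before the next round, then continue from the corrected token.
trimmed = Drafter.trim_kv_cache(out.past_key_values, invalid_token_num=2)
next_token = torch.tensor([[42]], device="cuda")  # placeholder for the verifier's token
next_out = drafter.speculate(next_token, 5, trimmed)
```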


@@ -44,6 +44,7 @@ def _flash_decoding_fwd_kernel(
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size:
return
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v
@@ -52,7 +53,8 @@
# and then support calculating multiple kv cache blocks on an instance
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
# get the current (kv) sequence length
# cur_token_off is used as a "mask" here for spec-dec during verification process
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return
@@ -150,7 +152,9 @@ def _flash_decoding_fwd_reduce_kernel(
return
cur_head_idx = tl.program_id(1)
# cur_token_off is used as a "mask" here for spec-dec during verification process
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
offsets_dmodel = tl.arange(0, HEAD_DIM)
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
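The offset arithmetic added above can be sanity-checked outside Triton. A toy rendition (assumed values) of how `cur_token_off` masks out the not-yet-verified tokens for each of the `q_len` query tokens:

```python
# During verification each sequence contributes q_len query tokens, and their
# k/v entries are already written to the cache. The offset shrinks the visible
# kv length so that query token j only attends to keys up to its own position,
# i.e. kv_seq_len - q_len + 1 + j.
q_len = 4         # num_tokens_to_verify + 1
kv_seq_len = 12   # cached kv length of the sequence

for cur_token_idx in range(q_len):
    cur_token_off = (cur_token_idx % q_len) - q_len + 1  # -3, -2, -1, 0
    effective_kv_len = kv_seq_len + cur_token_off
    print(cur_token_idx, effective_kv_len)  # (0, 9) (1, 10) (2, 11) (3, 12)
```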


@@ -2,10 +2,15 @@ import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.spec.drafter import Drafter from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
NUM_LAYERS = 2
MAX_LEN = 100
@pytest.mark.parametrize("spec_num", [5])
@@ -14,13 +19,13 @@ def test_drafter(spec_num: int):
device = get_current_device()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
drafter = Drafter(drafter_model, tokenizer, device=device)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num)
@@ -29,13 +34,75 @@ def test_drafter(spec_num: int):
assert out.speculated_length == spec_num
assert out.next_tokens.shape == (spec_num,)
assert out.logits.shape == (spec_num, len(tokenizer))
assert out.past_key_values[0][0].size(2) == past_kv_length
reject_num = max(0, spec_num - 1)
trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def check_sd():
torch.manual_seed(123)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Dummy configs for testing
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
large_config = LlamaConfig(
hidden_size=4096,
intermediate_size=11008,
num_attention_heads=32,
num_hidden_layers=8,
num_key_value_heads=32,
max_position_embeddings=2048,
)
large_config.pad_token_id = tokenizer.eos_token_id
main_model = LlamaForCausalLM(large_config)
inference_config = InferenceConfig(
dtype="fp16",
micro_batch_size=1,
max_batch_size=1,
max_input_len=128,
max_output_len=128,
prefill_ratio=1.2,
block_size=16,
)
engine = InferenceEngine(main_model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
max_length=MAX_LEN,
eos_token_id=tokenizer.eos_token_id,
)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.disable_spec_dec()
engine.clear_spec_dec()
assert not engine.use_spec_dec
assert engine.drafter is None and engine.drafter_model is None
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_sd()
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
spawn(run_dist, nprocs=1)
if __name__ == "__main__": if __name__ == "__main__":
test_drafter(spec_num=5) test_drafter(spec_num=5)
test_spec_dec()


@@ -19,12 +19,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
assert q_len <= kv_len
causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
for i in range(bsz):
cur_seq_len = kv_lengths[i].item()
assert cur_seq_len <= kv_len
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
padding_mask[:, :, -q_len:, -q_len:] += causal_mask
return padding_mask
@@ -56,11 +63,13 @@
attn_scores = qk / (head_dim**0.5)
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
)
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
out = torch.matmul(attn_weights, v)
if out.size() != (bsz, num_heads, q_len, head_dim):


@@ -6,8 +6,8 @@ from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
torch_attn_ref,
)
@@ -91,9 +91,9 @@ def test_flash_decoding(
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
out_torch = torch_attn_ref(
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
@@ -138,6 +138,5 @@ def test_flash_decoding(
assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
if __name__ == "__main__":
test_flash_decoding(16, 32, 32, 16, 1, True)