mirror of https://github.com/hpcaitech/ColossalAI
[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors * add singleton class holding intermediate values * fix triton kernel api * add benchmark in pytest * fix kernel api and add benchmark * revise flash decoding triton kernel in/out shapes * fix calling of triton kernel in modeling * fix pytest: extract to util functionspull/5283/head
parent
9e2342bde2
commit
6e487e7d3c
|
@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
|
|||
|
||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
||||
from colossalai.inference.struct import BatchInfo
|
||||
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
|
||||
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
||||
|
@ -209,7 +209,15 @@ def llama_attn_forward(
|
|||
if HAS_TRITON:
|
||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||
# TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
|
||||
# in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
|
||||
# should be revised, as we could see in previous part of `llama_attn_forward` we still have
|
||||
# redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
attn_output = flash_decoding_attention(
|
||||
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
||||
)
|
||||
attn_output = attn_output.squeeze(1)
|
||||
else:
|
||||
attn_output = PagedAttention.pad_decoding_forward(
|
||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
||||
|
|
|
@ -9,7 +9,9 @@ except ImportError:
|
|||
# There may exist import error even if we have triton installed.
|
||||
if HAS_TRITON:
|
||||
from .context_attn_unpad import context_attention_unpadded
|
||||
from .flash_decoding import flash_decoding_fwd
|
||||
from .flash_decoding import flash_decoding_attention
|
||||
from .flash_decoding_utils import FDIntermTensors
|
||||
|
||||
from .rms_layernorm import rms_layernorm
|
||||
from .gptq_triton import gptq_fused_linear_triton
|
||||
from .kvcache_copy import copy_kv_to_blocked_cache
|
||||
|
@ -18,10 +20,11 @@ if HAS_TRITON:
|
|||
|
||||
__all__ = [
|
||||
"context_attention_unpadded",
|
||||
"flash_decoding_fwd",
|
||||
"flash_decoding_attention",
|
||||
"copy_kv_to_blocked_cache",
|
||||
"softmax",
|
||||
"rms_layernorm",
|
||||
"gptq_fused_linear_triton",
|
||||
"rotary_embedding",
|
||||
"FDIntermTensors",
|
||||
]
|
||||
|
|
|
@ -9,15 +9,16 @@ import triton.language as tl
|
|||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _flash_decoding_fwd_kernel(
|
||||
Q, # [batch_size, head_num, head_dim]
|
||||
Q, # [batch_size, head_num, q_len(1), head_dim]
|
||||
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||
context_lengths, # [batch_size]
|
||||
kv_seq_len, # [batch_size]
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_ql,
|
||||
stride_qd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
|
@ -51,7 +52,7 @@ def _flash_decoding_fwd_kernel(
|
|||
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
||||
|
||||
# get the current (kv) sequence length from provided context lengths tensor
|
||||
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
||||
|
||||
offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
||||
q = tl.load(Q + offsets_q)
|
||||
|
@ -65,7 +66,6 @@ def _flash_decoding_fwd_kernel(
|
|||
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
|
||||
|
||||
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
||||
# TODO might want to remove if-else block?
|
||||
return
|
||||
|
||||
cur_occupied_size = tl.where(
|
||||
|
@ -132,7 +132,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
|||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
|
||||
context_lengths,
|
||||
kv_seq_len,
|
||||
stride_mid_ot,
|
||||
stride_mid_oh,
|
||||
stride_mid_ob,
|
||||
|
@ -141,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
|||
stride_o_lseh,
|
||||
stride_o_lseb,
|
||||
stride_ob,
|
||||
stride_ol,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
BLOCK_KV: tl.constexpr,
|
||||
|
@ -149,7 +150,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
|||
cur_seq_idx = tl.program_id(0)
|
||||
cur_head_idx = tl.program_id(1)
|
||||
|
||||
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
|
||||
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
|
||||
|
@ -181,21 +182,46 @@ def _flash_decoding_fwd_reduce_kernel(
|
|||
|
||||
# Decoding Stage
|
||||
# Used with blocked KV Cache (PagedAttention)
|
||||
def flash_decoding_fwd(
|
||||
q: torch.Tensor, # [bsz(e.g.num_tokens), 1, num_heads, head_dim]
|
||||
k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
context_lengths: torch.Tensor, # [batch_size]
|
||||
block_tables: torch.Tensor, # [batch_size, max_blocks_per_sequence]
|
||||
def flash_decoding_attention(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
kv_seq_len: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
block_size: int,
|
||||
num_kv_group: int = 1,
|
||||
max_seq_len_in_batch: int = None,
|
||||
mid_output: torch.Tensor = None,
|
||||
mid_output_lse: torch.Tensor = None,
|
||||
sm_scale: int = None,
|
||||
kv_group_num: int = 1,
|
||||
):
|
||||
bsz, _, num_heads, head_dim = q.shape
|
||||
"""
|
||||
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
kv_seq_len (torch.Tensor): [batch_size]
|
||||
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
||||
max_seq_len_in_batch (int): Maximum sequence length in the batch.
|
||||
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
|
||||
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
|
||||
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
|
||||
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
|
||||
block_size (int): Size of each block in the blocked key/value cache.
|
||||
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Output tensor with shape [bsz, num_heads, q_len, head_dim]
|
||||
"""
|
||||
bsz, num_heads, _, head_dim = q.shape
|
||||
|
||||
assert head_dim in {32, 64, 128, 256}
|
||||
assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f"batch size {bsz}"
|
||||
)
|
||||
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
|
||||
|
@ -203,75 +229,79 @@ def flash_decoding_fwd(
|
|||
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
|
||||
f"v_cache block_size {v_cache.size(-1)}"
|
||||
)
|
||||
# NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||
bsz = context_lengths.size(0) # e.g. the number of seqs
|
||||
max_seq_len = context_lengths.max().item()
|
||||
sm_scale = 1.0 / (head_dim**0.5)
|
||||
|
||||
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
|
||||
# For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`)
|
||||
assert block_size in {16, 32, 64, 128}
|
||||
BLOCK_KV = block_size
|
||||
|
||||
kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV
|
||||
mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
|
||||
mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale
|
||||
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
|
||||
# For compatibility (TODO revise modeling in future)
|
||||
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
|
||||
mid_output = (
|
||||
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
|
||||
if mid_output is None
|
||||
else mid_output
|
||||
)
|
||||
mid_output_lse = (
|
||||
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
if mid_output_lse is None
|
||||
else mid_output_lse
|
||||
)
|
||||
|
||||
if q.dim() == 4:
|
||||
assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}"
|
||||
q = q.squeeze(1)
|
||||
|
||||
grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV))
|
||||
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
|
||||
_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
mid_o,
|
||||
mid_o_lse,
|
||||
context_lengths,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
kv_seq_len,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
mid_o.stride(0),
|
||||
mid_o.stride(1),
|
||||
mid_o.stride(2),
|
||||
mid_o.stride(3),
|
||||
mid_o_lse.stride(0),
|
||||
mid_o_lse.stride(1),
|
||||
mid_o_lse.stride(2),
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
sm_scale,
|
||||
KV_GROUPS=num_kv_group,
|
||||
KV_GROUPS=kv_group_num,
|
||||
BLOCK_KV=block_size,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
|
||||
output = torch.zeros_like(q)
|
||||
output = output.view(-1, output.size(-2), output.size(-1))
|
||||
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
|
||||
|
||||
grid = (bsz, num_heads)
|
||||
_flash_decoding_fwd_reduce_kernel[grid](
|
||||
mid_o,
|
||||
mid_o_lse,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
output,
|
||||
context_lengths,
|
||||
mid_o.stride(0),
|
||||
mid_o.stride(1),
|
||||
mid_o.stride(2),
|
||||
mid_o.stride(3),
|
||||
mid_o_lse.stride(0),
|
||||
mid_o_lse.stride(1),
|
||||
mid_o_lse.stride(2),
|
||||
kv_seq_len,
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
output.stride(3),
|
||||
BLOCK_KV=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import torch
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class FDIntermTensors(metaclass=SingletonMeta):
|
||||
"""Singleton class to hold tensors used for storing intermediate values in flash-decoding.
|
||||
For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tensors_initialized = False
|
||||
|
||||
@property
|
||||
def is_initialized(self):
|
||||
return self._tensors_initialized
|
||||
|
||||
@property
|
||||
def mid_output(self):
|
||||
assert self.is_initialized, "Intermediate tensors not initialized yet"
|
||||
return self._mid_output
|
||||
|
||||
@property
|
||||
def mid_output_lse(self):
|
||||
assert self.is_initialized, "Intermediate tensors not initialized yet"
|
||||
return self._mid_output_lse
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
max_batch_size: int,
|
||||
num_attn_heads: int,
|
||||
kv_max_split_num: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: torch.device = get_current_device(),
|
||||
) -> None:
|
||||
"""Initialize tensors.
|
||||
|
||||
Args:
|
||||
max_batch_size (int): The maximum batch size over all the model forward.
|
||||
This could be greater than the batch size in attention forward func when using dynamic batch size.
|
||||
num_attn_heads (int)): Number of attention heads.
|
||||
kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm.
|
||||
**The maximum length/size of blocks splitted on kv should be the kv cache block size.**
|
||||
head_dim (int): Head dimension.
|
||||
dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
|
||||
device (torch.device, optional): Device used to initialize intermediate tensors.
|
||||
"""
|
||||
assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."
|
||||
|
||||
self._mid_output = torch.empty(
|
||||
size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
self._mid_output_lse = torch.empty(
|
||||
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
|
||||
)
|
||||
self._tensors_initialized = True
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
@ -17,13 +19,22 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
|
||||
|
||||
|
||||
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"):
|
||||
padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device)
|
||||
for i in range(bsz):
|
||||
cur_seq_len = kv_lengths[i].item()
|
||||
assert cur_seq_len <= kv_seq_len
|
||||
padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf")
|
||||
return padding_mask
|
||||
|
||||
|
||||
# Attention calculation adapted from HuggingFace transformers repository
|
||||
# src/transformers/models/llama/modeling_llama.py
|
||||
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
|
||||
def torch_attn_ref(
|
||||
q: torch.Tensor, # [bsz, seq_len, num_heads, head_dim]
|
||||
k: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim]
|
||||
v: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim]
|
||||
q: torch.Tensor, # [bsz, num_heads, q_len, head_dim]
|
||||
k: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim]
|
||||
v: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim]
|
||||
attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len]
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
|
@ -31,14 +42,8 @@ def torch_attn_ref(
|
|||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim
|
||||
q = q.view(bsz, seq_len, num_heads, head_dim)
|
||||
k = k.view(bsz, kv_seq_len, num_kv_heads, head_dim)
|
||||
v = v.view(bsz, kv_seq_len, num_kv_heads, head_dim)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# repeat kv for GQA and MQA
|
||||
# k/v won't change if kv_group_num is 1
|
||||
|
@ -49,7 +54,6 @@ def torch_attn_ref(
|
|||
|
||||
qk = torch.matmul(q, k.transpose(2, 3))
|
||||
attn_scores = qk / (head_dim**0.5)
|
||||
|
||||
assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"
|
||||
# for left-side padding
|
||||
if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
|
||||
|
@ -77,7 +81,7 @@ def mock_alloc_block_table_and_kvcache(
|
|||
num_seqs: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
block_size: int,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
|
||||
block_id = 0
|
||||
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
|
||||
|
@ -102,12 +106,10 @@ def mock_alloc_block_table_and_kvcache(
|
|||
return block_tables
|
||||
|
||||
|
||||
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int):
|
||||
"""Allocate 1 token on the block table for each seqs in block tables.
|
||||
It won't change provided context_lengths
|
||||
"""
|
||||
|
||||
# consider max_block_id as the last physical block allocated
|
||||
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
|
||||
# Allocate 1 token on the block table for each seqs in block tables.
|
||||
# It won't change provided context_lengths.
|
||||
# Consider max_block_id as the last physical block allocated
|
||||
# NOTE It assumes all the blocks preceding this block have been allocated
|
||||
max_block_id = torch.max(block_tables).item()
|
||||
# the indices on each block table representing the cache block to be allocated one more token
|
||||
|
@ -126,3 +128,36 @@ def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.T
|
|||
if new_block_ids.numel():
|
||||
new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
|
||||
block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
|
||||
|
||||
|
||||
def generate_caches_and_block_tables(
|
||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
|
||||
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
|
||||
_, num_kv_heads, head_dim = k_unpad.shape
|
||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
# Mock allocation on block tables as well as blocked kv caches
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
return k_cache, v_cache, block_tables
|
||||
|
||||
|
||||
def convert_kv_unpad_to_padded(
|
||||
k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
|
||||
) -> torch.Tensor:
|
||||
# Rebuild (batched) k/v with padding to be used by torch attention
|
||||
# input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
|
||||
# returns k/v padded [bsz, num_kv_heads, max_seq_len, head_dim]
|
||||
_, num_kv_heads, head_dim = k_unpad.shape
|
||||
k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
|
||||
prev_len_sum = 0
|
||||
for i, seq_len in enumerate(kv_seq_lengths.tolist()):
|
||||
# left-side padding
|
||||
k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
|
||||
prev_len_sum += seq_len
|
||||
k_torch = k_torch.transpose(1, 2)
|
||||
return k_torch
|
||||
|
|
|
@ -4,7 +4,7 @@ from packaging import version
|
|||
|
||||
from colossalai.kernel.triton import context_attention_unpadded
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref
|
||||
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
@ -16,6 +16,8 @@ except ImportError:
|
|||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
HEAD_DIM = 32
|
||||
|
||||
|
||||
def torch_attn_unpad(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int
|
||||
|
@ -34,9 +36,9 @@ def torch_attn_unpad(
|
|||
mask[mask == 0.0] = float("-inf")
|
||||
|
||||
torch_attn_ref_out = torch_attn_ref(
|
||||
q[start_idx:end_idx].unsqueeze(0),
|
||||
k[start_idx:end_idx].unsqueeze(0),
|
||||
v[start_idx:end_idx].unsqueeze(0),
|
||||
q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
|
||||
k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
|
||||
v[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
|
||||
mask,
|
||||
1, # set bsz as 1 as we're processing sequence one by one
|
||||
seq_len,
|
||||
|
@ -74,7 +76,6 @@ def test_context_attention(
|
|||
|
||||
num_kv_heads = num_attn_heads // kv_group_num
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
head_dim = 32
|
||||
max_seq_len = max_num_blocks_per_seq * block_size
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
@ -85,28 +86,28 @@ def test_context_attention(
|
|||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_dim)
|
||||
qkv = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
q, k, v = torch.split(qkv, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
|
||||
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
|
||||
k_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
k_cache_triton = torch.zeros_like(k_cache_torch)
|
||||
v_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
v_cache_triton = torch.zeros_like(v_cache_torch)
|
||||
|
||||
# Mock allocation on block tables
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k, v, k_cache_torch, v_cache_torch, context_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
|
||||
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
k_cache_triton = torch.zeros_like(k_cache_ref)
|
||||
v_cache_triton = torch.zeros_like(v_cache_ref)
|
||||
|
||||
out_triton = context_attention_unpadded(
|
||||
q, k, v, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
|
||||
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
|
||||
)
|
||||
|
||||
out_torch = torch_attn_unpad(q, k, v, context_lengths, num_attn_heads, num_kv_heads)
|
||||
out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
|
||||
|
||||
assert out_torch.shape == out_triton.shape
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||
assert torch.allclose(k_cache_torch, k_cache_triton)
|
||||
assert torch.allclose(v_cache_torch, v_cache_triton)
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-3)
|
||||
assert torch.equal(k_cache_ref, k_cache_triton)
|
||||
assert torch.equal(v_cache_ref, v_cache_triton)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_context_attention(4, 32, 8, 16, 1, True)
|
||||
|
|
|
@ -2,9 +2,14 @@ import pytest
|
|||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.kernel.triton import flash_decoding_fwd
|
||||
from colossalai.kernel.triton import flash_decoding_attention
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref
|
||||
from tests.test_infer_ops.triton.kernel_utils import (
|
||||
convert_kv_unpad_to_padded,
|
||||
generate_caches_and_block_tables,
|
||||
prepare_padding_mask,
|
||||
torch_attn_ref,
|
||||
)
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
@ -16,23 +21,37 @@ except ImportError:
|
|||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
Q_LEN = 1
|
||||
HEAD_DIM = 128
|
||||
|
||||
def torch_decoding(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor):
|
||||
assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor"
|
||||
assert q.size(1) == 1, "Only used for decoding"
|
||||
assert k.shape == v.shape
|
||||
|
||||
bsz, _, num_heads, head_dim = q.shape
|
||||
_, kv_seq_len, num_kv_heads, _ = k.shape
|
||||
assert num_heads % num_kv_heads == 0, "Invalid kv heads and attention heads."
|
||||
padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=q.device)
|
||||
for i in range(bsz):
|
||||
cur_seq_len = context_lengths[i].item()
|
||||
assert cur_seq_len <= kv_seq_len
|
||||
padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf")
|
||||
def prepare_data(
|
||||
bsz: int,
|
||||
num_attn_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
same_context_len: bool,
|
||||
q_len: int,
|
||||
max_kv_seq_len: int,
|
||||
dtype=torch.float16,
|
||||
device="cuda",
|
||||
):
|
||||
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
|
||||
# otherwise generate random context lengths.
|
||||
kv_lengths = (
|
||||
torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
if same_context_len
|
||||
else torch.randint(low=1, high=max_kv_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
)
|
||||
num_tokens = torch.sum(kv_lengths).item()
|
||||
|
||||
out = torch_attn_ref(q, k, v, padding_mask, bsz, 1, kv_seq_len, num_heads, num_kv_heads, head_dim)
|
||||
return out
|
||||
q_size = (bsz, q_len, num_attn_heads, head_dim)
|
||||
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
|
||||
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
||||
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
return q, k_unpad, v_unpad, kv_lengths
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
|
@ -57,59 +76,135 @@ def test_flash_decoding(
|
|||
|
||||
num_kv_heads = num_attn_heads // kv_group_num
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
q_len = 1
|
||||
head_dim = 128
|
||||
max_seq_len = block_size * max_num_blocks_per_seq
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
if same_context_len:
|
||||
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
q_size = (bsz, q_len, num_attn_heads, head_dim)
|
||||
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
||||
kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
# Mock allocation on block tables as well as blocked kv caches
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
|
||||
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
||||
)
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables(
|
||||
k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
|
||||
q = q.view(bsz, q_len, num_attn_heads, head_dim)
|
||||
out_triton = flash_decoding_fwd(
|
||||
# The maximum sequence length in the batch (if context lengths randomly generated)
|
||||
max_seq_len_in_b = kv_seq_lengths.max().item()
|
||||
# The maximum block length splitted on kv should be the kv cache block size
|
||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||
mid_output = torch.empty(
|
||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||
)
|
||||
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
||||
out_triton = flash_decoding_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
context_lengths,
|
||||
kv_seq_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_group_num,
|
||||
)
|
||||
out_triton = out_triton.unsqueeze(1) # [bsz, 1, num_heads, head_dim]
|
||||
max_seq_len_in_b,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=kv_group_num,
|
||||
) # [bsz, 1, num_heads, head_dim]
|
||||
|
||||
# rebuild (batched) kv with padding for torch attention
|
||||
# q [bsz, 1, num_heads, head_dim]
|
||||
# k/v [num_tokens, num_kv_heads, head_dim]
|
||||
max_seq_len = context_lengths.max().item()
|
||||
k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
|
||||
v_torch = torch.zeros_like(k_torch)
|
||||
prev_len_sum = 0
|
||||
for i, seq_len in enumerate(context_lengths.tolist()):
|
||||
# mock left-side padding
|
||||
k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len]
|
||||
v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len]
|
||||
prev_len_sum += seq_len
|
||||
# k/v [bsz, max_seq_len, num_kv_heads, head_dim]
|
||||
out_torch = torch_decoding(q, k_torch, v_torch, context_lengths)
|
||||
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
|
||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
|
||||
torch_padding_mask = prepare_padding_mask(kv_seq_lengths, bsz, max_seq_len_in_b, q.device)
|
||||
out_torch = torch_attn_ref(
|
||||
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
||||
)
|
||||
|
||||
assert out_torch.shape == out_triton.shape
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||
|
||||
|
||||
BATCH = 16
|
||||
BLOCK_SIZE = 32
|
||||
SAME_LEN = True
|
||||
WARM_UPS = 10
|
||||
REPS = 100
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["KV_LEN"],
|
||||
x_vals=[2**i for i in range(8, 14)],
|
||||
# x_vals=[x for x in range(256, 8192, 256)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "triton"],
|
||||
line_names=["Torch", "Triton"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
|
||||
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_kernel(
|
||||
bsz,
|
||||
KV_LEN,
|
||||
provider,
|
||||
block_size: int,
|
||||
kv_group_num: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
num_attn_heads = 16
|
||||
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
|
||||
max_seq_len = block_size * max_num_blocks_per_seq
|
||||
|
||||
num_kv_heads = num_attn_heads // kv_group_num
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
block_size * max_num_blocks_per_seq
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
q, k_unpad, v_unpad, kv_lengths = prepare_data(
|
||||
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
||||
)
|
||||
max_seq_len_in_b = kv_lengths.max().item() # for random lengths
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "torch":
|
||||
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
|
||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
|
||||
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
|
||||
fn = lambda: torch_attn_ref(
|
||||
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
if provider == "triton":
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables(
|
||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
# the maximum block length splitted on kv should be the kv cache block size
|
||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||
mid_output = torch.empty(
|
||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||
)
|
||||
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
||||
fn = lambda: flash_decoding_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
kv_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
max_seq_len_in_b,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=kv_group_num,
|
||||
) # [bsz, 1, num_heads, head_dim]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flash_decoding(16, 32, 32, 16, 1, True)
|
||||
# bench_kernel.run(save_path=".", print_data=True)
|
||||
|
|
Loading…
Reference in New Issue