[kernel] Support New KCache Layout - Triton Kernel (#5677)

* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Yuanheng Zhao 2024-05-03 17:20:45 +08:00 committed by GitHub
parent 9df016fc45
commit 537a3cbc4d
10 changed files with 428 additions and 206 deletions
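For orientation: the "new kcache layout" this PR wires through the Triton kernels splits the head dimension of the blocked K cache into head_dim // x chunks of x contiguous elements, turning [num_blocks, num_kv_heads, block_size, head_dim] into [num_blocks, num_kv_heads, head_dim // x, block_size, x]. A minimal sketch of how the two layouts relate, with illustrative names and sizes (not code from the PR):

    import torch

    num_blocks, num_kv_heads, block_size, head_dim = 4, 2, 16, 128
    # x = number of elements covering a 16-byte vector for the cache dtype (8 for fp16)
    x = 16 // torch.tensor([], dtype=torch.float16).element_size()

    k_cache_old = torch.randn(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16)
    # Old element (block, head, token, d) lands at (block, head, d // x, token, d % x):
    k_cache_new = (
        k_cache_old.reshape(num_blocks, num_kv_heads, block_size, head_dim // x, x)
        .permute(0, 1, 3, 2, 4)
        .contiguous()
    )
    b, h, t, d = 1, 0, 5, 37
    assert torch.equal(k_cache_old[b, h, t, d], k_cache_new[b, h, d // x, t, d % x])

That (d // x, d % x) decomposition is exactly what the kernels below implement through strides.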


@@ -338,8 +338,8 @@ def _fwd_context_paged_attention_kernel_v2(
         X_range = tl.arange(0, KCACHE_X)
         # unroll the loop aggressively
         for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
-            offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
-            offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt
+            offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
+            offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt
             k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0)
             # HACK: KCache must be contiguous in order to apply the following offsets calculation
             offsets_kcache = (


@@ -11,20 +11,29 @@ import triton.language as tl
 def _flash_decoding_fwd_kernel(
     Q,  # [batch_size * q_len, head_num, head_dim]
     KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
-    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
+    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim],
+    # or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided
     block_tables,  # [batch_size, max_blocks_per_sequence]
     mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
     mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
     q_len,
     batch_size,
+    kv_group_num,
+    x,
+    sm_scale,
     stride_qt,
     stride_qh,
     stride_qd,
-    stride_cacheb,
-    stride_cacheh,
-    stride_cachebs,
-    stride_cached,
+    stride_kcb,
+    stride_kch,
+    stride_kcsplit_x,
+    stride_kcs,
+    stride_kcd,
+    stride_vcb,
+    stride_vch,
+    stride_vcs,
+    stride_vcd,
     stride_bts,
     stride_btb,
     stride_mid_ot,
@@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel(
     stride_mid_o_lset,
     stride_mid_o_lseh,
     stride_mid_o_lseb,
-    sm_scale,
-    KV_GROUPS: tl.constexpr,
     BLOCK_KV: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HEAD_DIM: tl.constexpr,
@@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel(
     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
         return

     offsets_dmodel = tl.arange(0, HEAD_DIM)
-    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
-    q = tl.load(Q + offsets_q)
+    offsets_block = tl.arange(0, BLOCK_SIZE)
     # block table for the current sequence
     block_table_ptr = block_tables + cur_seq_idx * stride_bts
     # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
@@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel(
     )
     tl.device_assert(cur_occupied_size >= 0)

-    cur_kv_head_idx = cur_head_idx // KV_GROUPS
-    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
-    K_block_ptr = tl.make_block_ptr(
-        base=KCache + offset_kvcache,
-        shape=(cur_occupied_size, HEAD_DIM),
-        strides=(stride_cachebs, stride_cached),
-        offsets=(0, 0),
-        block_shape=(BLOCK_SIZE, HEAD_DIM),
-        order=(0, 1),
-    )
+    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
+    q = tl.load(Q + offsets_q)
+    cur_kv_head_idx = cur_head_idx // kv_group_num
+    offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch
+    offsets_k = (
+        offset_kvcache
+        + (offsets_dmodel[None, :] // x) * stride_kcsplit_x
+        + (offsets_dmodel[None, :] % x) * stride_kcd
+        + offsets_block[:, None] * stride_kcs
+    )
+    k_cur_block = tl.load(KCache + offsets_k)
     V_block_ptr = tl.make_block_ptr(
         base=VCache + offset_kvcache,
         shape=(cur_occupied_size, HEAD_DIM),
-        strides=(stride_cachebs, stride_cached),
+        strides=(stride_vcs, stride_vcd),
         offsets=(0, 0),
         block_shape=(BLOCK_SIZE, HEAD_DIM),
         order=(0, 1),
     )
-    k_cur_block = tl.load(K_block_ptr)
     v_cur_block = tl.load(V_block_ptr)
     acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
     # use block size of the paged/blocked kv cache
@@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel(
         # Refer to https://github.com/openai/triton/discussions/895
         S_ij += tl.sum(q[None, :] * k_cur_block, 1)
         S_ij *= sm_scale
-        S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
+        S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf"))
         m = tl.max(S_ij, 0)
         S_ij -= m
@@ -324,6 +330,7 @@ def flash_decoding_attention(
     sm_scale: int = None,
     kv_group_num: int = 1,
     q_len: int = 1,  # NOTE alibi flash decoding does not support q_len > 1 at this moment.
+    use_new_kcache_layout: bool = False,
 ):
     """
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -349,6 +356,7 @@
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
         q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
             Defaults to 1.
+        use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False.

     Returns:
         Output tensor with shape [bsz * q_len, num_heads * head_dim]
@@ -400,13 +408,20 @@
     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
-    grid = (
+    grid = lambda META: (
         triton.next_power_of_2(bsz * q_len),
         num_heads,
-        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
+        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]),
     )

     if alibi_slopes is not None:
+        # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
+        # the code (alibi kernel) will be refactored later to avoid code duplication, when
+        # the whole triton flow with new k cache layout has been supported and tested.
+        assert (
+            not use_new_kcache_layout
+        ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
         _alibi_flash_decoding_fwd_kernel[grid](
             q,
             k_cache,
@@ -441,6 +456,19 @@
             HEAD_DIM=head_dim,
         )
     else:
+        # For KCache and VCache with the same layout
+        x = head_dim
+        kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
+        # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
+        if use_new_kcache_layout:
+            assert (
+                k_cache.dim() == 5
+                and k_cache.shape[1] == v_cache.shape[1]
+                and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
+            ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
+            x = k_cache.size(-1)
+            kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]
+
         _flash_decoding_fwd_kernel[grid](
             q,
             k_cache,
@@ -451,13 +479,21 @@
             kv_seq_len,
             q_len,
             bsz,
+            kv_group_num,
+            x,
+            sm_scale,
             q.stride(0),
             q.stride(1),
             q.stride(2),
             k_cache.stride(0),
             k_cache.stride(1),
-            k_cache.stride(2),
-            k_cache.stride(3),
+            kcsplit_x_stride,
+            kcs_stride,
+            kcd_stride,
+            v_cache.stride(0),
+            v_cache.stride(1),
+            v_cache.stride(2),
+            v_cache.stride(3),
             block_tables.stride(0),
             block_tables.stride(1),
             mid_output.stride(0),
@@ -467,8 +503,6 @@
             mid_output_lse.stride(0),
             mid_output_lse.stride(1),
             mid_output_lse.stride(2),
-            sm_scale,
-            KV_GROUPS=kv_group_num,
             BLOCK_KV=block_size,
             BLOCK_SIZE=block_size,
             HEAD_DIM=head_dim,
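One detail worth noting in the wrapper above: a single kernel now serves both layouts purely through the strides it receives. With the old 4-D layout the wrapper passes x = head_dim and stride_kcsplit_x = 0, so the kernel's (d // x) * stride_kcsplit_x + (d % x) * stride_kcd offset degenerates to the old per-element addressing. A tiny sketch of that identity (stride values illustrative):

    # With x == head_dim, d // x == 0 and d % x == d for every in-range d,
    # so the split term vanishes and only the per-element stride survives.
    head_dim = 128
    x = head_dim
    stride_kcsplit_x, stride_kcd = 0, 1  # illustrative strides
    for d in range(head_dim):
        assert (d // x) * stride_kcsplit_x + (d % x) * stride_kcd == d * stride_kcd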


@@ -4,56 +4,69 @@ import triton.language as tl

 # Triton 2.1.0
+# supports two types of cache layouts
+# 1. [num_blocks, num_kv_heads, block_size, head_dim]
+# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x]
 @triton.jit
 def _copy_to_kcache_seqlen_n_kernel(
-    KV,  # K or V
-    KVCache,  # KCache or VCache
+    K,  # K or V
+    KCache,  # [num_blocks, num_kv_heads, head_dim // x, block_size, x]
     BLOCK_TABLES,
-    context_lengths,
+    seq_lengths,
     stride_kt,
     stride_kh,
     stride_kd,
-    stride_cacheb,
-    stride_cacheh,
-    stride_cachebs,
-    stride_cached,
+    stride_kcb,
+    stride_kch,
+    stride_kcsplit_x,
+    stride_kcs,
+    stride_kcx,
     stride_bts,
     stride_btb,
     block_size,
-    n,
+    n_tokens,
     HEAD_DIM: tl.constexpr,
+    KCACHE_X: tl.constexpr,
 ):
+    # `n_tokens` is used to specify the number of tokens to copy for each sequence
+    # When n_tokens > 1, tokens from different sequences are packed into the first dimension of the grid,
+    # `seq_lengths` must be the lengths of sequences counting the number of tokens to copy
+    # E.g. if n_tokens = 5, seq_lengths = [12, 15], then the already-copied position ids are [0-6, 0-9]
+    # for the two sequences, respectively. And the position ids to be copied are [7-11, 9-14].
+    # When n_tokens = 1, consider token idx as the sequence idx, since it's only used during regular decoding stage
     cur_token_idx = tl.program_id(0)
-    cur_seq_idx = cur_token_idx // n
-    cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1))
-    # cur_token_shift = cur_token_idx - n * cur_seq_idx
+    cur_seq_idx = cur_token_idx // n_tokens
+    # `cur_token_shift` is only valid and functional when `n_tokens` > 1
+    cur_token_shift = cur_token_idx - (n_tokens * (cur_seq_idx + 1))
     cur_kv_head_idx = tl.program_id(1)
+    split_x_idx = tl.program_id(2)

-    past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift
+    past_kv_seq_len = tl.load(seq_lengths + cur_seq_idx) + cur_token_shift
     last_bt_block_idx = past_kv_seq_len // block_size
     block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
     block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
     offset_last_block = past_kv_seq_len % block_size
-    offsets_dmodel = tl.arange(0, HEAD_DIM)
-    offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
-    kv = tl.load(KV + offsets_kv)
-    offsets_kvcache = (
-        block_id * stride_cacheb
-        + cur_kv_head_idx * stride_cacheh
-        + offset_last_block * stride_cachebs
-        + offsets_dmodel * stride_cached
-    )
-    tl.store(KVCache + offsets_kvcache, kv)
+    offsets_dmodel = split_x_idx * KCACHE_X + tl.arange(0, KCACHE_X)
+    offsets_k = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
+    k = tl.load(K + offsets_k)
+    offsets_kcache = (
+        block_id * stride_kcb
+        + cur_kv_head_idx * stride_kch
+        + split_x_idx * stride_kcsplit_x
+        + offset_last_block * stride_kcs
+        + tl.arange(0, KCACHE_X)
+    )
+    tl.store(KCache + offsets_kcache, k)
     return


 # Triton 2.1.0
 @triton.jit
 def _copy_to_kvcache_seqlen1_kernel(
-    K,  # K
-    V,  # V
-    KCache,  # KCache
-    VCache,  # VCache
+    K,
+    V,
+    KCache,
+    VCache,
     BLOCK_TABLES,
     context_lengths,
     stride_kt,
@@ -62,18 +75,20 @@ def _copy_to_kvcache_seqlen1_kernel(
     stride_vt,
     stride_vh,
     stride_vd,
-    stride_cachekb,
-    stride_cachekh,
-    stride_cachekbs,
-    stride_cachekd,
-    stride_cachevb,
-    stride_cachevh,
-    stride_cachevbs,
-    stride_cachevd,
+    stride_kcb,
+    stride_kch,
+    stride_kcsplit_x,
+    stride_kcs,
+    stride_kcd,
+    stride_vcb,
+    stride_vch,
+    stride_vcs,
+    stride_vcd,
     stride_bts,
     stride_btb,
     block_size,
     HEAD_DIM: tl.constexpr,
+    KCACHE_X: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
     cur_kv_head_idx = tl.program_id(1)
@@ -83,33 +98,42 @@ def _copy_to_kvcache_seqlen1_kernel(
     block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
     block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
     offsets_in_last_block = past_kv_seq_len % block_size
-    offsets_dmodel = tl.arange(0, HEAD_DIM)
-    offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
-    offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd
-    k = tl.load(K + offsets_k)
-    v = tl.load(V + offsets_v)
-    offsets_kcache = (
-        block_id * stride_cachekb
-        + cur_kv_head_idx * stride_cachekh
-        + offsets_in_last_block * stride_cachekbs
-        + offsets_dmodel * stride_cachekd
-    )
-    offsets_vcache = (
-        block_id * stride_cachevb
-        + cur_kv_head_idx * stride_cachevh
-        + offsets_in_last_block * stride_cachevbs
-        + offsets_dmodel * stride_cachevd
-    )
-    tl.store(KCache + offsets_kcache, k)
-    tl.store(VCache + offsets_vcache, v)
+    range_x = tl.arange(0, KCACHE_X)
+    offsets_dmodel_x_partition = tl.arange(0, KCACHE_X)
+    for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
+        offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
+        offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel_x_partition * stride_kd
+        k = tl.load(K + offsets_k)
+        offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel_x_partition * stride_vd
+        v = tl.load(V + offsets_v)
+
+        offsets_kcache = (
+            block_id * stride_kcb
+            + cur_kv_head_idx * stride_kch
+            + split_x * stride_kcsplit_x
+            + offsets_in_last_block * stride_kcs
+            + range_x
+        )
+        tl.store(KCache + offsets_kcache, k)
+        offsets_vcache = (
+            block_id * stride_vcb
+            + cur_kv_head_idx * stride_vch
+            + offsets_in_last_block * stride_vcs
+            + offsets_dmodel_x_partition * stride_vcd
+        )
+        tl.store(VCache + offsets_vcache, v)
     return


 def copy_k_to_blocked_cache(
-    k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
+    k: torch.Tensor,
+    k_cache: torch.Tensor,
+    kv_lengths: torch.Tensor,
+    block_tables: torch.Tensor,
+    n: int = 1,
+    use_new_kcache_layout: bool = False,
 ):
     """
     Copy keys or values to the blocked key/value cache during decoding stage.
@@ -118,16 +142,17 @@ def copy_k_to_blocked_cache(
         k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
             [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
+            new KCache Layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]
         kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
         block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
         n (int): Number of tokens to copy for each sequence. Default to 1.
+        use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False.
     """
-    assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
     assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
-    k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
-    assert k.dim() == 3, f"Invalid k dim {k.dim()}"
-    bsz, num_kv_heads, head_dim = k.shape
+    if k.dim() == 4:
+        k = k.reshape(-1, k.size(-2), k.size(-1))
+    k_shape = k.shape
+    bsz, num_kv_heads, head_dim = k_shape

     # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
     if n > 1:
         assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
@@ -139,12 +164,24 @@ def copy_k_to_blocked_cache(
         f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
     )

+    k_cache_shape = k_cache.shape
     # Modify if the shape of kv cahce is changed.
-    block_size = k_cache.size(-2)
+    block_size = k_cache_shape[-2]
+
+    x = head_dim
+    stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)
+    if use_new_kcache_layout:
+        # when using kcache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]
+        assert (
+            len(k_cache_shape) == 5
+            and k_cache_shape[1] == k_shape[1]
+            and k_cache_shape[2] * k_cache_shape[4] == k_shape[2]
+        ), f"Incompatible k_cache shape {k_cache_shape} with k shape {k_shape}"
+        x = k_cache.size(-1)
+        stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]

     num_warps = 8 if head_dim > 128 else 4
-    grid = (bsz * n, num_kv_heads)
+    grid = (bsz * n, num_kv_heads, head_dim // x)
     _copy_to_kcache_seqlen_n_kernel[grid](
         k,
         k_cache,
@@ -155,13 +192,15 @@ def copy_k_to_blocked_cache(
         k.stride(2),
         k_cache.stride(0),
         k_cache.stride(1),
-        k_cache.stride(2),
-        k_cache.stride(3),
+        stride_kcsplit_x,
+        stride_kcs,
+        stride_kcd,
         block_tables.stride(0),
         block_tables.stride(1),
         block_size,
-        n=n,
+        n_tokens=n,
         HEAD_DIM=head_dim,
+        KCACHE_X=x,
         num_warps=num_warps,
     )

@@ -173,6 +212,7 @@ def copy_kv_to_blocked_cache(
     v_cache: torch.Tensor,
     kv_lengths: torch.Tensor,
     block_tables: torch.Tensor,
+    use_new_kcache_layout: bool = False,
 ):
     """
     Copy keys or values to the blocked key/value cache during decoding stage.
@@ -184,19 +224,30 @@ def copy_kv_to_blocked_cache(
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
         kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
         block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
+        use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False.
     """
-    assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
-    assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
+    k_cache_shape = k_cache.shape
+    v_cache_shape = v_cache.shape
+
+    if use_new_kcache_layout:
+        assert (
+            len(k_cache_shape) == 5
+            and k_cache_shape[1] == v_cache_shape[1]
+            and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3]
+        ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
+    else:
+        assert k.size(-1) == k_cache_shape[-1], "Incompatible head dim"
+        assert (
+            k_cache_shape == v_cache_shape
+        ), f"Incompatible KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
+    assert v.size(-1) == v_cache_shape[-1], "Incompatible head dim"
+
     k = k.squeeze(1) if k.dim() == 4 else k
     assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
-
-    assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
-    assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
     v = v.squeeze(1) if v.dim() == 4 else v
     assert v.dim() == 3, f"Incompatible v dim {v.dim()}"

     bsz, num_kv_heads, head_dim = k.shape
+
     assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
         f"Got incompatible batch size (number of seqs):\n"
         f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
@@ -206,6 +257,12 @@ def copy_kv_to_blocked_cache(
     # Modify if the shape of kv cahce is changed.
     block_size = k_cache.size(-2)

+    x = head_dim
+    stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)
+    if use_new_kcache_layout:
+        x = k_cache.size(-1)
+        stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]
+
     num_warps = 8 if head_dim > 128 else 4
     grid = (bsz, num_kv_heads)
     _copy_to_kvcache_seqlen1_kernel[grid](
@@ -223,8 +280,9 @@ def copy_kv_to_blocked_cache(
         v.stride(2),
         k_cache.stride(0),
         k_cache.stride(1),
-        k_cache.stride(2),
-        k_cache.stride(3),
+        stride_kcsplit_x,
+        stride_kcs,
+        stride_kcd,
         v_cache.stride(0),
         v_cache.stride(1),
         v_cache.stride(2),
@@ -233,5 +291,6 @@ def copy_kv_to_blocked_cache(
         block_tables.stride(1),
         block_size,
         HEAD_DIM=head_dim,
+        KCACHE_X=x,
         num_warps=num_warps,
     )
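As a mental model for the decode-stage copy above, the following plain-PyTorch sketch writes one new K token per sequence into the new-layout cache. The function is a hypothetical oracle, not part of the PR, and assumes kv_lengths already counts the token being written, as the wrapper's docstring states:

    import torch

    def copy_k_seqlen1_reference(k, k_cache, kv_lengths, block_tables):
        # k: [bsz, num_kv_heads, head_dim]; k_cache: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
        bsz, num_kv_heads, head_dim = k.shape
        block_size, x = k_cache.size(-2), k_cache.size(-1)
        for i in range(bsz):
            pos = int(kv_lengths[i]) - 1  # slot of the newly decoded token
            block_id = block_tables[i, pos // block_size]
            offset = pos % block_size
            # split head_dim into (head_dim // x, x), mirroring the kernel's addressing
            k_cache[block_id, :, :, offset, :] = k[i].view(num_kv_heads, head_dim // x, x)
        return k_cache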


@@ -1,3 +1,4 @@
+import warnings
 from typing import Optional

 import torch
@@ -85,8 +86,8 @@ def rotary_embedding_kernel(
         mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
     )

-    handle_k = cur_head_idx % KV_GROUP_NUM == 0
-    if handle_k:
+    handle_kv = cur_head_idx % KV_GROUP_NUM == 0
+    if handle_kv:
         k_head_idx = cur_head_idx // KV_GROUP_NUM
         off_k0 = (
             tokens_range[:, None, None] * k_token_stride
@@ -385,6 +386,7 @@ def decoding_fused_rotary_embedding_kernel(
     v_cache,
     BLOCK_TABLES,
     context_lengths,
+    x,
     q_token_stride,
     q_head_stride,
     k_token_stride,
@@ -392,10 +394,15 @@ def decoding_fused_rotary_embedding_kernel(
     head_dim_stride,
     cos_token_stride,
     cos_stride,
-    cache_b_stride,
-    cache_h_stride,
-    cache_bs_stride,
-    cache_d_stride,
+    kcb_stride,
+    kch_stride,
+    kcsplit_x_stride,
+    kcs_stride,
+    kcd_stride,
+    vcb_stride,
+    vch_stride,
+    vcs_stride,
+    vcd_stride,
     bts_stride,
     btb_stride,
     block_size,
@@ -424,8 +431,8 @@ def decoding_fused_rotary_embedding_kernel(
     tl.store(q + off_q0, out_q0)
     tl.store(q + off_q1, out_q1)

-    handle_k = cur_head_idx % KV_GROUP_NUM == 0
-    if handle_k:
+    handle_kv = cur_head_idx % KV_GROUP_NUM == 0
+    if handle_kv:
         cur_k_head_idx = cur_head_idx // KV_GROUP_NUM
         off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride
         off_k0 = off_kv + dim_range0 * head_dim_stride
@@ -443,17 +450,18 @@ def decoding_fused_rotary_embedding_kernel(
         last_block_idx = past_kv_seq_len // block_size
         block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)
         offsets_in_last_block = past_kv_seq_len % block_size
+        offsets_cache_base = block_ids * kcb_stride + cur_k_head_idx * kch_stride
         k_range0 = (
-            block_ids * cache_b_stride
-            + cur_k_head_idx * cache_h_stride
-            + offsets_in_last_block * cache_bs_stride
-            + dim_range0 * cache_d_stride
+            offsets_cache_base
+            + offsets_in_last_block * kcs_stride
+            + (dim_range0 // x) * kcsplit_x_stride
+            + (dim_range0 % x) * kcd_stride
         )
         k_range1 = (
-            block_ids * cache_b_stride
-            + cur_k_head_idx * cache_h_stride
-            + offsets_in_last_block * cache_bs_stride
-            + dim_range1 * cache_d_stride
+            offsets_cache_base
+            + offsets_in_last_block * kcs_stride
+            + (dim_range1 // x) * kcsplit_x_stride
+            + (dim_range1 % x) * kcd_stride
         )
         tl.store(k_cache + k_range0, out_k0)
         tl.store(k_cache + k_range1, out_k1)
@@ -461,10 +469,10 @@ def decoding_fused_rotary_embedding_kernel(
         off_v = off_kv + dim_range * head_dim_stride
         loaded_v = tl.load(v + off_v)
         v_range = (
-            block_ids * cache_b_stride
-            + cur_k_head_idx * cache_h_stride
-            + offsets_in_last_block * cache_bs_stride
-            + dim_range * cache_d_stride
+            block_ids * vcb_stride
+            + cur_k_head_idx * vch_stride
+            + offsets_in_last_block * vcs_stride
+            + dim_range * vcd_stride
         )
         tl.store(v_cache + v_range, loaded_v)
@@ -532,6 +540,7 @@ def rotary_embedding(
             num_warps=num_warps,
         )
     else:
+        warnings.warn("Fused rotary embedding Triton kernel will be deprecated as the new kcache layout is supported")
         grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
         fused_rotary_embedding_kernel_v2[grid](
             q,
@@ -573,6 +582,7 @@ def decoding_fused_rotary_embedding(
     v_cache: Optional[torch.Tensor] = None,
     block_tables: Optional[torch.Tensor] = None,
     kv_lengths: Optional[torch.Tensor] = None,
+    use_new_kcache_layout: bool = False,
 ):
     """
     Args:
@@ -588,8 +598,6 @@
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
-    assert k.size(1) == v.size(1)
-    assert k_cache.size(-1) == v_cache.size(-1)

     if head_dim >= 512:
         num_warps = 16
@@ -597,18 +605,22 @@
         num_warps = 8
     else:
         num_warps = 4
-    q_token_stride = q.stride(0)
-    q_head_stride = q.stride(1)
-    head_dim_stride = q.stride(2)
-    k_token_stride = k.stride(0)
-    k_head_stride = k.stride(1)
+
     k_head_num = k.size(1)
     kv_group_num = q_head_num // k_head_num
-    cos_token_stride = cos.stride(0)
-    cos_stride = cos.stride(1)
+
+    # For KCache and VCache with the same layout
+    x = head_dim
+    kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
+    # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
+    if use_new_kcache_layout:
+        assert (
+            k_cache.dim() == 5
+            and k_cache.shape[1] == v_cache.shape[1]
+            and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
+        ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
+        x = k_cache.size(-1)
+        kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]

     grid = (q_head_num, q_total_tokens)
     decoding_fused_rotary_embedding_kernel[grid](
         q,
@@ -620,17 +632,23 @@
         v_cache,
         block_tables,
         kv_lengths,
-        q_token_stride,
-        q_head_stride,
-        k_token_stride,
-        k_head_stride,
-        head_dim_stride,
-        cos_token_stride,
-        cos_stride,
+        x,
+        q.stride(0),
+        q.stride(1),
+        k.stride(0),
+        k.stride(1),
+        q.stride(2),
+        cos.stride(0),
+        cos.stride(1),
         k_cache.stride(0),
         k_cache.stride(1),
-        k_cache.stride(2),
-        k_cache.stride(3),
+        kcsplit_x_stride,
+        kcs_stride,
+        kcd_stride,
+        v_cache.stride(0),
+        v_cache.stride(1),
+        v_cache.stride(2),
+        v_cache.stride(3),
         block_tables.stride(0),
         block_tables.stride(1),
         k_cache.size(-2),
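The benchmarks and tests that follow all size the innermost cache dimension as x = 16 // element_size, i.e. however many elements of the cache dtype span 16 bytes, presumably so each (chunk, token) slot is one aligned 16-byte vector access. A quick illustration:

    import torch

    # Elements per 16-byte vector for common cache dtypes.
    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        print(dtype, x)  # float16 -> 8, bfloat16 -> 8, float32 -> 4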


@@ -6,6 +6,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
 from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
@@ -29,9 +30,9 @@ configs = [
         x_vals=[2**i for i in range(8, 14)],
         # x_vals=[x for x in range(256, 8192, 256)],
         line_arg="provider",
-        line_vals=["torch", "triton"],
-        line_names=["Torch", "Triton"],
-        styles=[("red", "-"), ("blue", "-")],
+        line_vals=["torch", "triton", "triton_new_kcache_layout"],
+        line_names=["Torch", "Triton", "Triton New KCache Layout"],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
         ylabel="ms",
         plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
         args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
@@ -62,6 +63,14 @@
         bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
     )
     max_seq_len_in_b = kv_lengths.max().item()  # for random lengths
+    # the maximum block length splitted on kv should be the kv cache block size
+    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    sm_scale = 1.0 / (HEAD_DIM**0.5)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
+    mid_output = torch.empty(
+        size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
+    )
+    mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)

     quantiles = [0.5, 0.2, 0.8]
     if provider == "torch":
@@ -81,19 +90,11 @@
             HEAD_DIM,
         )
         ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
-    if provider == "triton":
+    elif provider == "triton":
         k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
             k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
         )
         block_tables = block_tables.to(device=device)
-        # the maximum block length splitted on kv should be the kv cache block size
-        kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-        output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
-        mid_output = torch.empty(
-            size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
-        )
-        mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-        sm_scale = 1.0 / (HEAD_DIM**0.5)
         fn = lambda: flash_decoding_attention(
             # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
             # refer to attention forward in modeling.
@@ -111,6 +112,29 @@
             kv_group_num=kv_group_num,
         )  # [bsz, 1, num_heads, head_dim]
         ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
+    elif provider == "triton_new_kcache_layout":
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
+        block_tables = block_tables.to(device=device)
+        fn = lambda: flash_decoding_attention(
+            # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
+            # refer to attention forward in modeling.
+            q.squeeze(2),
+            k_cache,
+            v_cache,
+            kv_lengths,
+            block_tables,
+            block_size,
+            max_seq_len_in_b,
+            output,
+            mid_output,
+            mid_output_lse,
+            sm_scale=sm_scale,
+            kv_group_num=kv_group_num,
+            use_new_kcache_layout=True,
+        )
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)

     return ms, min_ms, max_ms


@@ -24,18 +24,20 @@ configs = [
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
         line_vals=[
-            "no_fused_triton_rotary_emb_func",
-            "fused_triton_rotary_emb_func",
-            "no_fused_cuda_rotary_emb_func",
-            "fused_cuda_rotary_emb_func",
+            "triton_rotary_emb_func",
+            "triton_fused_rotary_emb_func",
+            "triton_fused_rotary_emb_func_new_kcache_layout",
+            "cuda_rotary_emb_func",
+            "cuda_fused_rotary_emb_func",
         ],
         line_names=[
-            "no_fused_triton_rotary_emb_func",
-            "fused_triton_rotary_emb_func",
-            "no_fused_cuda_rotary_emb_func",
-            "fused_cuda_rotary_emb_func",
+            "triton_rotary_emb_func",
+            "triton_fused_rotary_emb_func",
+            "triton_fused_rotary_emb_func(new layout)",
+            "cuda_rotary_emb_func",
+            "cuda_fused_rotary_emb_func",
         ],
-        styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
+        styles=[("red", "-"), ("blue", "-"), ("purple", "-"), ("green", "-"), ("yellow", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
         args={"num_kv_heads": 16},
@@ -91,31 +93,44 @@
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")

-    if provider == "no_fused_triton_rotary_emb_func":
+    quantiles = [0.5, 0.2, 0.8]
+    if provider == "triton_rotary_emb_func":
         fn = lambda: [
             rotary_embedding(new_q, new_k, cos, sin),
             copy_kv_to_blocked_cache(
                 new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
             ),
         ]
-    elif provider == "fused_triton_rotary_emb_func":
+    elif provider == "triton_fused_rotary_emb_func":
         fn = lambda: decoding_fused_rotary_embedding(
             new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
         )
-    elif provider == "no_fused_cuda_rotary_emb_func":
+    elif provider == "triton_fused_rotary_emb_func_new_kcache_layout":
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
+        k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
+        block_tables = mock_alloc_block_table_and_kvcache_v3(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
+        mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
+        block_tables = block_tables.to(device="cuda")
+        fn = lambda: decoding_fused_rotary_embedding(
+            new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout=True
+        )
+    elif provider == "cuda_rotary_emb_func":
         fn = lambda: [
             inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
             inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
         ]
-    elif provider == "fused_cuda_rotary_emb_func":
+    elif provider == "cuda_fused_rotary_emb_func":
         fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
             new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
         )
     else:
         raise ValueError("Undefined provider")

-    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-    return ms
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+    return ms, min_ms, max_ms


 if __name__ == "__main__":


@@ -14,7 +14,7 @@ except ImportError:

 inference_ops = InferenceOpsLoader().load()

-HEAD_DIM = 4
+HEAD_DIM = 128
 BATCH = 16
 BLOCK_SIZE = 32
 SAME_LEN = True
@@ -25,9 +25,9 @@ configs = [
         x_names=["KV_SEQ_LEN"],
         x_vals=[2**i for i in range(8, 13)],
         line_arg="provider",
-        line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
-        line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
-        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
+        line_vals=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
+        line_names=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
         ylabel="ms",
         plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
         args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
@@ -45,7 +45,7 @@
     num_kv_heads: int,
     same_context_len: bool,
 ):
-    dtype = torch.float32
+    dtype = torch.float16
     device = get_current_device()

     assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
@@ -63,11 +63,18 @@
     )

     quantiles = [0.5, 0.2, 0.8]
+    # TODO copy_to_cache needs to support copying both k and v at the same time in the future.
     if provider == "torch_copy_func":
         fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
     elif provider == "triton_copy_func":
         fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
+    elif provider == "triton_new_kcache_layout":
+        # NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) to be applied
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        k_cache_shape = (bsz * max_seq_len // block_size, num_kv_heads, HEAD_DIM // x, block_size, x)
+        k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)  # update k_cache layout
+        fn = lambda: copy_kv_to_blocked_cache(
+            new_k, new_v, k_cache, v_cache, context_lengths, block_tables, use_new_kcache_layout=True
+        )
     elif provider == "cuda_copy_func":
         _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
             bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype


@@ -10,6 +10,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
 from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
@@ -75,6 +76,7 @@ def prepare_data(
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("q_len", [1, 5])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
 def test_flash_decoding(
     bsz: int,
     block_size: int,
@@ -84,7 +86,15 @@ def test_flash_decoding(
     same_context_len: bool,
     q_len: int,
     use_alibi_slopes: bool,
+    use_new_kcache_layout: bool,
 ):
+    if use_new_kcache_layout and use_alibi_slopes:
+        # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
+        # the code (alibi kernel) will be refactored later to avoid code duplication, when
+        # the whole triton flow with new k cache layout has been supported and tested.
+        # And tests for the alibi kernel using new kcache layout will be added then.
+        pytest.skip("Alibi kernel does not support new kcache layout yet.")
+
     torch.manual_seed(123)
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
@@ -127,9 +137,14 @@ def test_flash_decoding(
         q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
     )

-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
-    )
+    if use_new_kcache_layout:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
+    else:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
     block_tables = block_tables.to(device=device)
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
@@ -165,6 +180,7 @@ def test_flash_decoding(
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
         q_len=q_len,
+        use_new_kcache_layout=use_new_kcache_layout,
     )  # [bsz * q_len, num_heads, head_dim]

     assert out_torch.shape == out_triton.shape
@@ -178,4 +194,4 @@ def test_flash_decoding(

 if __name__ == "__main__":
-    test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)
+    test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True)


@@ -4,7 +4,11 @@ from packaging import version

 from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
+    mock_alloc_single_token,
+)

 try:
     import triton  # noqa
@@ -30,6 +34,7 @@ def prepare_data(
     n=1,
     device="cuda",
     dtype=torch.float16,
+    use_new_kcache_layout=False,
 ):
     assert max_seq_len > n, "max_seq_len must be greater than n"

@@ -44,9 +49,14 @@ def prepare_data(
     kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
-    )
+    if use_new_kcache_layout:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
+        )
+    else:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
+        )
     block_tables = block_tables.to(device=device)

     new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
@@ -66,8 +76,15 @@
 @pytest.mark.parametrize("num_kv_heads", [16])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("n_tokens", [1, 5])
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
 def test_copy_kv_to_caches(
-    bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
+    bsz: int,
+    block_size: int,
+    max_num_blocks_per_seq: int,
+    num_kv_heads: int,
+    same_context_len: bool,
+    n_tokens: int,
+    use_new_kcache_layout: bool,
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -89,6 +106,7 @@ def test_copy_kv_to_caches(
         n_tokens,
         device=device,
         dtype=dtype,
+        use_new_kcache_layout=use_new_kcache_layout,
     )
     k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
     v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
@@ -98,7 +116,9 @@ def test_copy_kv_to_caches(
     offsets_in_block = past_kv_seq_lengths % block_size

     # Copy k (or v) to k (or v) cache
-    copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
+    copy_k_to_blocked_cache(
+        new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
+    )
     # Reshape target k from k cache to compare if matching with original tensor
     # Mainly to handle cases of n_tokens > 1
     k_target = []
@@ -110,26 +130,39 @@ def test_copy_kv_to_caches(
         while tokens_left > 0:
             tokens_to_fill = min(block_size - offset, tokens_left)
             curr_block_id = block_table[curr_kv_len // block_size]
-            k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
+            if use_new_kcache_layout:
+                k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
+            else:
+                k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
             curr_kv_len += tokens_to_fill
             tokens_left -= tokens_to_fill
             offset = 0
-    k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous()  # [bsz * n, num_kv_heads, head_dim]
+    if use_new_kcache_layout:
+        k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
+        k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
+    else:
+        k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous()  # [bsz * n, num_kv_heads, head_dim]

     assert k_target.shape == k_source.shape
     assert torch.equal(k_target, k_source)

     if n_tokens == 1:
         # Copy k and v to k/v caches
         k_cache = k_cache_copy
-        copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
-        k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
-        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+        copy_kv_to_blocked_cache(
+            new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
+        )
+        if use_new_kcache_layout:
+            k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
+            k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
+        else:
+            k_target = k_cache[target_block_ids, :, offsets_in_block, :]
         assert k_target.shape == k_source.shape
         assert torch.equal(k_target, k_source)
+        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
         assert v_target.shape == v_source.shape
         assert torch.equal(v_target, v_source)


 if __name__ == "__main__":
-    test_copy_kv_to_caches(4, 32, 8, 16, True)
+    test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1)
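The permute/reshape in the test above follows directly from the layout mapping: a slice k_cache[block, :, :, off : off + t, :] has shape [num_kv_heads, head_dim // x, t, x], so moving the token axis to the front and flattening the (head_dim // x, x) pair restores [t, num_kv_heads, head_dim]. A standalone sketch with illustrative sizes:

    import torch

    num_kv_heads, head_dim, x, t = 2, 64, 8, 3
    chunk = torch.randn(num_kv_heads, head_dim // x, t, x)  # e.g. k_cache[block, :, :, off : off + t, :]
    k_target = chunk.permute(2, 0, 1, 3).contiguous()       # -> [t, num_kv_heads, head_dim // x, x]
    k_target = k_target.reshape(t, num_kv_heads, head_dim)  # chunks are adjacent along head_dim
    assert k_target.shape == (t, num_kv_heads, head_dim)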


@@ -4,7 +4,10 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.triton import decoding_fused_rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    mock_alloc_block_table_and_kvcache_v2,
+    mock_alloc_block_table_and_kvcache_v3,
+)

 try:
     import triton  # noqa
@@ -36,7 +39,8 @@ def torch_rotary_emb(x, cos, sin):
 @pytest.mark.parametrize("H", [32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
@@ -57,28 +61,40 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (TOTAL_TOKENS, H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
-    cos_shape = (TOTAL_TOKENS, D // 2)
-    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
-    v_cache = torch.zeros_like(k_cache)
-    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
-    block_tables = mock_alloc_block_table_and_kvcache_v2(
-        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
-    )
     new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_q = torch.randn_like(new_k)
     new_v = torch.randn_like(new_k)

+    cos_shape = (TOTAL_TOKENS, D // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
+    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
+    if use_new_kcache_layout:
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x)
+        k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
+        block_tables = mock_alloc_block_table_and_kvcache_v3(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
+    else:
+        k_cache = torch.zeros_like(v_cache)
+        block_tables = mock_alloc_block_table_and_kvcache_v2(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")

     q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
-    decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths)
+    decoding_fused_rotary_embedding(
+        new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout
+    )
     assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)


 if __name__ == "__main__":
-    test_rotary_emb(4, 64, 32, 64, torch.float32)
+    test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)