mirror of https://github.com/hpcaitech/ColossalAI
[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout * revise tests for new kcache layout * naive triton flash decoding - new kcache layout * rotary triton kernel - new kcache layout * remove redundancy - triton decoding * remove redundancy - triton kvcache copy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5685/head
parent
9df016fc45
commit
537a3cbc4d
|
@ -338,8 +338,8 @@ def _fwd_context_paged_attention_kernel_v2(
|
|||
X_range = tl.arange(0, KCACHE_X)
|
||||
# unroll the loop aggressively
|
||||
for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
|
||||
offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
|
||||
offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt
|
||||
offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
|
||||
offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt
|
||||
k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0)
|
||||
# HACK: KCache must be contiguous in order to apply the following offsets calculation
|
||||
offsets_kcache = (
|
||||
|
|
|
@ -11,20 +11,29 @@ import triton.language as tl
|
|||
def _flash_decoding_fwd_kernel(
|
||||
Q, # [batch_size * q_len, head_num, head_dim]
|
||||
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
VCache, # [num_blocks, num_kv_heads, block_size, head_dim],
|
||||
# or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided
|
||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
|
||||
kv_seq_len, # [batch_size]
|
||||
q_len,
|
||||
batch_size,
|
||||
kv_group_num,
|
||||
x,
|
||||
sm_scale,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_kcb,
|
||||
stride_kch,
|
||||
stride_kcsplit_x,
|
||||
stride_kcs,
|
||||
stride_kcd,
|
||||
stride_vcb,
|
||||
stride_vch,
|
||||
stride_vcs,
|
||||
stride_vcd,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
stride_mid_ot,
|
||||
|
@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel(
|
|||
stride_mid_o_lset,
|
||||
stride_mid_o_lseh,
|
||||
stride_mid_o_lseb,
|
||||
sm_scale,
|
||||
KV_GROUPS: tl.constexpr,
|
||||
BLOCK_KV: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
|
@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel(
|
|||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
|
||||
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
||||
return
|
||||
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
||||
q = tl.load(Q + offsets_q)
|
||||
offsets_block = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
# block table for the current sequence
|
||||
block_table_ptr = block_tables + cur_seq_idx * stride_bts
|
||||
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
|
||||
|
@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel(
|
|||
)
|
||||
tl.device_assert(cur_occupied_size >= 0)
|
||||
|
||||
cur_kv_head_idx = cur_head_idx // KV_GROUPS
|
||||
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=KCache + offset_kvcache,
|
||||
shape=(cur_occupied_size, HEAD_DIM),
|
||||
strides=(stride_cachebs, stride_cached),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||
order=(0, 1),
|
||||
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
||||
q = tl.load(Q + offsets_q)
|
||||
cur_kv_head_idx = cur_head_idx // kv_group_num
|
||||
offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch
|
||||
offsets_k = (
|
||||
offset_kvcache
|
||||
+ (offsets_dmodel[None, :] // x) * stride_kcsplit_x
|
||||
+ (offsets_dmodel[None, :] % x) * stride_kcd
|
||||
+ offsets_block[:, None] * stride_kcs
|
||||
)
|
||||
k_cur_block = tl.load(KCache + offsets_k)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=VCache + offset_kvcache,
|
||||
shape=(cur_occupied_size, HEAD_DIM),
|
||||
strides=(stride_cachebs, stride_cached),
|
||||
strides=(stride_vcs, stride_vcd),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||
order=(0, 1),
|
||||
)
|
||||
k_cur_block = tl.load(K_block_ptr)
|
||||
v_cur_block = tl.load(V_block_ptr)
|
||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||
# use block size of the paged/blocked kv cache
|
||||
|
@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel(
|
|||
# Refer to https://github.com/openai/triton/discussions/895
|
||||
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
|
||||
S_ij *= sm_scale
|
||||
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
|
||||
S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf"))
|
||||
|
||||
m = tl.max(S_ij, 0)
|
||||
S_ij -= m
|
||||
|
@ -324,6 +330,7 @@ def flash_decoding_attention(
|
|||
sm_scale: int = None,
|
||||
kv_group_num: int = 1,
|
||||
q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment.
|
||||
use_new_kcache_layout: bool = False,
|
||||
):
|
||||
"""
|
||||
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
|
||||
|
@ -349,6 +356,7 @@ def flash_decoding_attention(
|
|||
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
|
||||
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
|
||||
Defaults to 1.
|
||||
use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Output tensor with shape [bsz * q_len, num_heads * head_dim]
|
||||
|
@ -400,13 +408,20 @@ def flash_decoding_attention(
|
|||
|
||||
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
|
||||
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
|
||||
grid = (
|
||||
grid = lambda META: (
|
||||
triton.next_power_of_2(bsz * q_len),
|
||||
num_heads,
|
||||
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
|
||||
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]),
|
||||
)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
|
||||
# the code (alibi kernel) will be refactored later to avoid code duplication, when
|
||||
# the whole triton flow with new k cache layout has been supported and tested.
|
||||
assert (
|
||||
not use_new_kcache_layout
|
||||
), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
|
||||
|
||||
_alibi_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
|
@ -441,6 +456,19 @@ def flash_decoding_attention(
|
|||
HEAD_DIM=head_dim,
|
||||
)
|
||||
else:
|
||||
# For KCache and VCache with the same layout
|
||||
x = head_dim
|
||||
kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
|
||||
# For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
|
||||
if use_new_kcache_layout:
|
||||
assert (
|
||||
k_cache.dim() == 5
|
||||
and k_cache.shape[1] == v_cache.shape[1]
|
||||
and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
|
||||
), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
|
||||
x = k_cache.size(-1)
|
||||
kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]
|
||||
|
||||
_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
|
@ -451,13 +479,21 @@ def flash_decoding_attention(
|
|||
kv_seq_len,
|
||||
q_len,
|
||||
bsz,
|
||||
kv_group_num,
|
||||
x,
|
||||
sm_scale,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
kcsplit_x_stride,
|
||||
kcs_stride,
|
||||
kcd_stride,
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
mid_output.stride(0),
|
||||
|
@ -467,8 +503,6 @@ def flash_decoding_attention(
|
|||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
sm_scale,
|
||||
KV_GROUPS=kv_group_num,
|
||||
BLOCK_KV=block_size,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
|
|
|
@ -4,56 +4,69 @@ import triton.language as tl
|
|||
|
||||
|
||||
# Triton 2.1.0
|
||||
# supports two types of cache layouts
|
||||
# 1. [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x]
|
||||
@triton.jit
|
||||
def _copy_to_kcache_seqlen_n_kernel(
|
||||
KV, # K or V
|
||||
KVCache, # KCache or VCache
|
||||
K, # K or V
|
||||
KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x]
|
||||
BLOCK_TABLES,
|
||||
context_lengths,
|
||||
seq_lengths,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_kcb,
|
||||
stride_kch,
|
||||
stride_kcsplit_x,
|
||||
stride_kcs,
|
||||
stride_kcx,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
block_size,
|
||||
n,
|
||||
n_tokens,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
KCACHE_X: tl.constexpr,
|
||||
):
|
||||
# `n_tokens` is used to specify the number of tokens to copy for each sequence
|
||||
# When n_tokens > 1, tokens from different sequences are packed into the first dimension of the grid,
|
||||
# `seq_lengths` must be the lengths of sequences counting the number of tokens to copy
|
||||
# E.g. if n_tokens = 5, seq_lengths = [12, 15], then the already-copied position ids are [0-6, 0-9]
|
||||
# for the two sequences, respectively. And the position ids to be copied are [7-11, 9-14].
|
||||
# When n_tokens = 1, consider token idx as the sequence idx, since it's only used during regular decoding stage
|
||||
cur_token_idx = tl.program_id(0)
|
||||
cur_seq_idx = cur_token_idx // n
|
||||
cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1))
|
||||
# cur_token_shift = cur_token_idx - n * cur_seq_idx
|
||||
cur_seq_idx = cur_token_idx // n_tokens
|
||||
# `cur_token_shift` is only valid and functional when `n_tokens` > 1
|
||||
cur_token_shift = cur_token_idx - (n_tokens * (cur_seq_idx + 1))
|
||||
cur_kv_head_idx = tl.program_id(1)
|
||||
split_x_idx = tl.program_id(2)
|
||||
|
||||
past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift
|
||||
past_kv_seq_len = tl.load(seq_lengths + cur_seq_idx) + cur_token_shift
|
||||
last_bt_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||
offset_last_block = past_kv_seq_len % block_size
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cacheb
|
||||
+ cur_kv_head_idx * stride_cacheh
|
||||
+ offset_last_block * stride_cachebs
|
||||
+ offsets_dmodel * stride_cached
|
||||
offsets_dmodel = split_x_idx * KCACHE_X + tl.arange(0, KCACHE_X)
|
||||
offsets_k = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
k = tl.load(K + offsets_k)
|
||||
offsets_kcache = (
|
||||
block_id * stride_kcb
|
||||
+ cur_kv_head_idx * stride_kch
|
||||
+ split_x_idx * stride_kcsplit_x
|
||||
+ offset_last_block * stride_kcs
|
||||
+ tl.arange(0, KCACHE_X)
|
||||
)
|
||||
tl.store(KVCache + offsets_kvcache, kv)
|
||||
tl.store(KCache + offsets_kcache, k)
|
||||
return
|
||||
|
||||
|
||||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _copy_to_kvcache_seqlen1_kernel(
|
||||
K, # K
|
||||
V, # V
|
||||
KCache, # KCache
|
||||
VCache, # VCache
|
||||
K,
|
||||
V,
|
||||
KCache,
|
||||
VCache,
|
||||
BLOCK_TABLES,
|
||||
context_lengths,
|
||||
stride_kt,
|
||||
|
@ -62,18 +75,20 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_cachekb,
|
||||
stride_cachekh,
|
||||
stride_cachekbs,
|
||||
stride_cachekd,
|
||||
stride_cachevb,
|
||||
stride_cachevh,
|
||||
stride_cachevbs,
|
||||
stride_cachevd,
|
||||
stride_kcb,
|
||||
stride_kch,
|
||||
stride_kcsplit_x,
|
||||
stride_kcs,
|
||||
stride_kcd,
|
||||
stride_vcb,
|
||||
stride_vch,
|
||||
stride_vcs,
|
||||
stride_vcd,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
block_size,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
KCACHE_X: tl.constexpr,
|
||||
):
|
||||
cur_seq_idx = tl.program_id(0)
|
||||
cur_kv_head_idx = tl.program_id(1)
|
||||
|
@ -83,33 +98,42 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||
offsets_in_last_block = past_kv_seq_len % block_size
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd
|
||||
|
||||
range_x = tl.arange(0, KCACHE_X)
|
||||
offsets_dmodel_x_partition = tl.arange(0, KCACHE_X)
|
||||
|
||||
for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
|
||||
offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
|
||||
offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel_x_partition * stride_kd
|
||||
k = tl.load(K + offsets_k)
|
||||
offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel_x_partition * stride_vd
|
||||
v = tl.load(V + offsets_v)
|
||||
|
||||
offsets_kcache = (
|
||||
block_id * stride_cachekb
|
||||
+ cur_kv_head_idx * stride_cachekh
|
||||
+ offsets_in_last_block * stride_cachekbs
|
||||
+ offsets_dmodel * stride_cachekd
|
||||
block_id * stride_kcb
|
||||
+ cur_kv_head_idx * stride_kch
|
||||
+ split_x * stride_kcsplit_x
|
||||
+ offsets_in_last_block * stride_kcs
|
||||
+ range_x
|
||||
)
|
||||
offsets_vcache = (
|
||||
block_id * stride_cachevb
|
||||
+ cur_kv_head_idx * stride_cachevh
|
||||
+ offsets_in_last_block * stride_cachevbs
|
||||
+ offsets_dmodel * stride_cachevd
|
||||
)
|
||||
|
||||
tl.store(KCache + offsets_kcache, k)
|
||||
offsets_vcache = (
|
||||
block_id * stride_vcb
|
||||
+ cur_kv_head_idx * stride_vch
|
||||
+ offsets_in_last_block * stride_vcs
|
||||
+ offsets_dmodel_x_partition * stride_vcd
|
||||
)
|
||||
tl.store(VCache + offsets_vcache, v)
|
||||
return
|
||||
|
||||
|
||||
def copy_k_to_blocked_cache(
|
||||
k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
|
||||
k: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
kv_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
n: int = 1,
|
||||
use_new_kcache_layout: bool = False,
|
||||
):
|
||||
"""
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
@ -118,16 +142,17 @@ def copy_k_to_blocked_cache(
|
|||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
[bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||
new KCache Layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]
|
||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
n (int): Number of tokens to copy for each sequence. Default to 1.
|
||||
use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False.
|
||||
"""
|
||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
|
||||
k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
|
||||
assert k.dim() == 3, f"Invalid k dim {k.dim()}"
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
if k.dim() == 4:
|
||||
k = k.reshape(-1, k.size(-2), k.size(-1))
|
||||
k_shape = k.shape
|
||||
bsz, num_kv_heads, head_dim = k_shape
|
||||
# NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
|
||||
if n > 1:
|
||||
assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
|
||||
|
@ -139,12 +164,24 @@ def copy_k_to_blocked_cache(
|
|||
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
|
||||
)
|
||||
|
||||
k_cache_shape = k_cache.shape
|
||||
# Modify if the shape of kv cahce is changed.
|
||||
block_size = k_cache.size(-2)
|
||||
block_size = k_cache_shape[-2]
|
||||
|
||||
x = head_dim
|
||||
stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)
|
||||
if use_new_kcache_layout:
|
||||
# when using kcache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]
|
||||
assert (
|
||||
len(k_cache_shape) == 5
|
||||
and k_cache_shape[1] == k_shape[1]
|
||||
and k_cache_shape[2] * k_cache_shape[4] == k_shape[2]
|
||||
), f"Incompatible k_cache shape {k_cache_shape} with k shape {k_shape}"
|
||||
x = k_cache.size(-1)
|
||||
stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
|
||||
grid = (bsz * n, num_kv_heads)
|
||||
grid = (bsz * n, num_kv_heads, head_dim // x)
|
||||
_copy_to_kcache_seqlen_n_kernel[grid](
|
||||
k,
|
||||
k_cache,
|
||||
|
@ -155,13 +192,15 @@ def copy_k_to_blocked_cache(
|
|||
k.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
stride_kcsplit_x,
|
||||
stride_kcs,
|
||||
stride_kcd,
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
block_size,
|
||||
n=n,
|
||||
n_tokens=n,
|
||||
HEAD_DIM=head_dim,
|
||||
KCACHE_X=x,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
@ -173,6 +212,7 @@ def copy_kv_to_blocked_cache(
|
|||
v_cache: torch.Tensor,
|
||||
kv_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
use_new_kcache_layout: bool = False,
|
||||
):
|
||||
"""
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
@ -184,19 +224,30 @@ def copy_kv_to_blocked_cache(
|
|||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
|
||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False.
|
||||
"""
|
||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
k_cache_shape = k_cache.shape
|
||||
v_cache_shape = v_cache.shape
|
||||
|
||||
if use_new_kcache_layout:
|
||||
assert (
|
||||
len(k_cache_shape) == 5
|
||||
and k_cache_shape[1] == v_cache_shape[1]
|
||||
and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3]
|
||||
), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
|
||||
else:
|
||||
assert k.size(-1) == k_cache_shape[-1], "Incompatible head dim"
|
||||
assert (
|
||||
k_cache_shape == v_cache_shape
|
||||
), f"Incompatible KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
|
||||
assert v.size(-1) == v_cache_shape[-1], "Incompatible head dim"
|
||||
|
||||
k = k.squeeze(1) if k.dim() == 4 else k
|
||||
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||
|
||||
assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
|
||||
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
v = v.squeeze(1) if v.dim() == 4 else v
|
||||
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"
|
||||
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
|
||||
|
@ -206,6 +257,12 @@ def copy_kv_to_blocked_cache(
|
|||
# Modify if the shape of kv cahce is changed.
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
x = head_dim
|
||||
stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)
|
||||
if use_new_kcache_layout:
|
||||
x = k_cache.size(-1)
|
||||
stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
grid = (bsz, num_kv_heads)
|
||||
_copy_to_kvcache_seqlen1_kernel[grid](
|
||||
|
@ -223,8 +280,9 @@ def copy_kv_to_blocked_cache(
|
|||
v.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
stride_kcsplit_x,
|
||||
stride_kcs,
|
||||
stride_kcd,
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
|
@ -233,5 +291,6 @@ def copy_kv_to_blocked_cache(
|
|||
block_tables.stride(1),
|
||||
block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
KCACHE_X=x,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
@ -85,8 +86,8 @@ def rotary_embedding_kernel(
|
|||
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
|
||||
handle_k = cur_head_idx % KV_GROUP_NUM == 0
|
||||
if handle_k:
|
||||
handle_kv = cur_head_idx % KV_GROUP_NUM == 0
|
||||
if handle_kv:
|
||||
k_head_idx = cur_head_idx // KV_GROUP_NUM
|
||||
off_k0 = (
|
||||
tokens_range[:, None, None] * k_token_stride
|
||||
|
@ -385,6 +386,7 @@ def decoding_fused_rotary_embedding_kernel(
|
|||
v_cache,
|
||||
BLOCK_TABLES,
|
||||
context_lengths,
|
||||
x,
|
||||
q_token_stride,
|
||||
q_head_stride,
|
||||
k_token_stride,
|
||||
|
@ -392,10 +394,15 @@ def decoding_fused_rotary_embedding_kernel(
|
|||
head_dim_stride,
|
||||
cos_token_stride,
|
||||
cos_stride,
|
||||
cache_b_stride,
|
||||
cache_h_stride,
|
||||
cache_bs_stride,
|
||||
cache_d_stride,
|
||||
kcb_stride,
|
||||
kch_stride,
|
||||
kcsplit_x_stride,
|
||||
kcs_stride,
|
||||
kcd_stride,
|
||||
vcb_stride,
|
||||
vch_stride,
|
||||
vcs_stride,
|
||||
vcd_stride,
|
||||
bts_stride,
|
||||
btb_stride,
|
||||
block_size,
|
||||
|
@ -424,8 +431,8 @@ def decoding_fused_rotary_embedding_kernel(
|
|||
tl.store(q + off_q0, out_q0)
|
||||
tl.store(q + off_q1, out_q1)
|
||||
|
||||
handle_k = cur_head_idx % KV_GROUP_NUM == 0
|
||||
if handle_k:
|
||||
handle_kv = cur_head_idx % KV_GROUP_NUM == 0
|
||||
if handle_kv:
|
||||
cur_k_head_idx = cur_head_idx // KV_GROUP_NUM
|
||||
off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride
|
||||
off_k0 = off_kv + dim_range0 * head_dim_stride
|
||||
|
@ -443,17 +450,18 @@ def decoding_fused_rotary_embedding_kernel(
|
|||
last_block_idx = past_kv_seq_len // block_size
|
||||
block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)
|
||||
offsets_in_last_block = past_kv_seq_len % block_size
|
||||
offsets_cache_base = block_ids * kcb_stride + cur_k_head_idx * kch_stride
|
||||
k_range0 = (
|
||||
block_ids * cache_b_stride
|
||||
+ cur_k_head_idx * cache_h_stride
|
||||
+ offsets_in_last_block * cache_bs_stride
|
||||
+ dim_range0 * cache_d_stride
|
||||
offsets_cache_base
|
||||
+ offsets_in_last_block * kcs_stride
|
||||
+ (dim_range0 // x) * kcsplit_x_stride
|
||||
+ (dim_range0 % x) * kcd_stride
|
||||
)
|
||||
k_range1 = (
|
||||
block_ids * cache_b_stride
|
||||
+ cur_k_head_idx * cache_h_stride
|
||||
+ offsets_in_last_block * cache_bs_stride
|
||||
+ dim_range1 * cache_d_stride
|
||||
offsets_cache_base
|
||||
+ offsets_in_last_block * kcs_stride
|
||||
+ (dim_range1 // x) * kcsplit_x_stride
|
||||
+ (dim_range1 % x) * kcd_stride
|
||||
)
|
||||
tl.store(k_cache + k_range0, out_k0)
|
||||
tl.store(k_cache + k_range1, out_k1)
|
||||
|
@ -461,10 +469,10 @@ def decoding_fused_rotary_embedding_kernel(
|
|||
off_v = off_kv + dim_range * head_dim_stride
|
||||
loaded_v = tl.load(v + off_v)
|
||||
v_range = (
|
||||
block_ids * cache_b_stride
|
||||
+ cur_k_head_idx * cache_h_stride
|
||||
+ offsets_in_last_block * cache_bs_stride
|
||||
+ dim_range * cache_d_stride
|
||||
block_ids * vcb_stride
|
||||
+ cur_k_head_idx * vch_stride
|
||||
+ offsets_in_last_block * vcs_stride
|
||||
+ dim_range * vcd_stride
|
||||
)
|
||||
tl.store(v_cache + v_range, loaded_v)
|
||||
|
||||
|
@ -532,6 +540,7 @@ def rotary_embedding(
|
|||
num_warps=num_warps,
|
||||
)
|
||||
else:
|
||||
warnings.warn("Fused rotary embedding Triton kernel will be deprecated as the new kcache layout is supported")
|
||||
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
|
||||
fused_rotary_embedding_kernel_v2[grid](
|
||||
q,
|
||||
|
@ -573,6 +582,7 @@ def decoding_fused_rotary_embedding(
|
|||
v_cache: Optional[torch.Tensor] = None,
|
||||
block_tables: Optional[torch.Tensor] = None,
|
||||
kv_lengths: Optional[torch.Tensor] = None,
|
||||
use_new_kcache_layout: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -588,8 +598,6 @@ def decoding_fused_rotary_embedding(
|
|||
"""
|
||||
q_total_tokens, q_head_num, head_dim = q.shape
|
||||
assert q.size(0) == k.size(0) == v.size(0)
|
||||
assert k.size(1) == v.size(1)
|
||||
assert k_cache.size(-1) == v_cache.size(-1)
|
||||
|
||||
if head_dim >= 512:
|
||||
num_warps = 16
|
||||
|
@ -597,18 +605,22 @@ def decoding_fused_rotary_embedding(
|
|||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
q_token_stride = q.stride(0)
|
||||
q_head_stride = q.stride(1)
|
||||
head_dim_stride = q.stride(2)
|
||||
|
||||
k_token_stride = k.stride(0)
|
||||
k_head_stride = k.stride(1)
|
||||
k_head_num = k.size(1)
|
||||
kv_group_num = q_head_num // k_head_num
|
||||
|
||||
cos_token_stride = cos.stride(0)
|
||||
cos_stride = cos.stride(1)
|
||||
# For KCache and VCache with the same layout
|
||||
x = head_dim
|
||||
kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
|
||||
# For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
|
||||
if use_new_kcache_layout:
|
||||
assert (
|
||||
k_cache.dim() == 5
|
||||
and k_cache.shape[1] == v_cache.shape[1]
|
||||
and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
|
||||
), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
|
||||
x = k_cache.size(-1)
|
||||
kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]
|
||||
|
||||
grid = (q_head_num, q_total_tokens)
|
||||
decoding_fused_rotary_embedding_kernel[grid](
|
||||
q,
|
||||
|
@ -620,17 +632,23 @@ def decoding_fused_rotary_embedding(
|
|||
v_cache,
|
||||
block_tables,
|
||||
kv_lengths,
|
||||
q_token_stride,
|
||||
q_head_stride,
|
||||
k_token_stride,
|
||||
k_head_stride,
|
||||
head_dim_stride,
|
||||
cos_token_stride,
|
||||
cos_stride,
|
||||
x,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
q.stride(2),
|
||||
cos.stride(0),
|
||||
cos.stride(1),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
kcsplit_x_stride,
|
||||
kcs_stride,
|
||||
kcd_stride,
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
k_cache.size(-2),
|
||||
|
|
|
@ -6,6 +6,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
|
|||
convert_kv_unpad_to_padded,
|
||||
create_attention_mask,
|
||||
generate_caches_and_block_tables_v2,
|
||||
generate_caches_and_block_tables_v3,
|
||||
torch_attn_ref,
|
||||
)
|
||||
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
|
||||
|
@ -29,9 +30,9 @@ configs = [
|
|||
x_vals=[2**i for i in range(8, 14)],
|
||||
# x_vals=[x for x in range(256, 8192, 256)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "triton"],
|
||||
line_names=["Torch", "Triton"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
line_vals=["torch", "triton", "triton_new_kcache_layout"],
|
||||
line_names=["Torch", "Triton", "Triton New KCache Layout"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
|
||||
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
|
||||
|
@ -62,6 +63,14 @@ def bench_kernel(
|
|||
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
||||
)
|
||||
max_seq_len_in_b = kv_lengths.max().item() # for random lengths
|
||||
# the maximum block length splitted on kv should be the kv cache block size
|
||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
||||
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||
mid_output = torch.empty(
|
||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||
)
|
||||
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "torch":
|
||||
|
@ -81,19 +90,11 @@ def bench_kernel(
|
|||
HEAD_DIM,
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
if provider == "triton":
|
||||
elif provider == "triton":
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
# the maximum block length splitted on kv should be the kv cache block size
|
||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||
mid_output = torch.empty(
|
||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||
)
|
||||
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
||||
fn = lambda: flash_decoding_attention(
|
||||
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
|
||||
# refer to attention forward in modeling.
|
||||
|
@ -111,6 +112,29 @@ def bench_kernel(
|
|||
kv_group_num=kv_group_num,
|
||||
) # [bsz, 1, num_heads, head_dim]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
elif provider == "triton_new_kcache_layout":
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
|
||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
fn = lambda: flash_decoding_attention(
|
||||
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
|
||||
# refer to attention forward in modeling.
|
||||
q.squeeze(2),
|
||||
k_cache,
|
||||
v_cache,
|
||||
kv_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
max_seq_len_in_b,
|
||||
output,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=kv_group_num,
|
||||
use_new_kcache_layout=True,
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
|
|
@ -24,18 +24,20 @@ configs = [
|
|||
x_vals=[2**i for i in range(4, 11)],
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"no_fused_triton_rotary_emb_func",
|
||||
"fused_triton_rotary_emb_func",
|
||||
"no_fused_cuda_rotary_emb_func",
|
||||
"fused_cuda_rotary_emb_func",
|
||||
"triton_rotary_emb_func",
|
||||
"triton_fused_rotary_emb_func",
|
||||
"triton_fused_rotary_emb_func_new_kcache_layout",
|
||||
"cuda_rotary_emb_func",
|
||||
"cuda_fused_rotary_emb_func",
|
||||
],
|
||||
line_names=[
|
||||
"no_fused_triton_rotary_emb_func",
|
||||
"fused_triton_rotary_emb_func",
|
||||
"no_fused_cuda_rotary_emb_func",
|
||||
"fused_cuda_rotary_emb_func",
|
||||
"triton_rotary_emb_func",
|
||||
"triton_fused_rotary_emb_func",
|
||||
"triton_fused_rotary_emb_func(new layout)",
|
||||
"cuda_rotary_emb_func",
|
||||
"cuda_fused_rotary_emb_func",
|
||||
],
|
||||
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
|
||||
styles=[("red", "-"), ("blue", "-"), ("purple", "-"), ("green", "-"), ("yellow", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||
args={"num_kv_heads": 16},
|
||||
|
@ -91,31 +93,44 @@ def benchmark_rotary_emb(
|
|||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
block_tables = block_tables.to(device="cuda")
|
||||
|
||||
if provider == "no_fused_triton_rotary_emb_func":
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "triton_rotary_emb_func":
|
||||
fn = lambda: [
|
||||
rotary_embedding(new_q, new_k, cos, sin),
|
||||
copy_kv_to_blocked_cache(
|
||||
new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
|
||||
),
|
||||
]
|
||||
elif provider == "fused_triton_rotary_emb_func":
|
||||
elif provider == "triton_fused_rotary_emb_func":
|
||||
fn = lambda: decoding_fused_rotary_embedding(
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
|
||||
)
|
||||
elif provider == "no_fused_cuda_rotary_emb_func":
|
||||
elif provider == "triton_fused_rotary_emb_func_new_kcache_layout":
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
|
||||
k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v3(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
block_tables = block_tables.to(device="cuda")
|
||||
fn = lambda: decoding_fused_rotary_embedding(
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout=True
|
||||
)
|
||||
elif provider == "cuda_rotary_emb_func":
|
||||
fn = lambda: [
|
||||
inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||
]
|
||||
elif provider == "fused_cuda_rotary_emb_func":
|
||||
elif provider == "cuda_fused_rotary_emb_func":
|
||||
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
|
||||
new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
|
||||
)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -14,7 +14,7 @@ except ImportError:
|
|||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
HEAD_DIM = 4
|
||||
HEAD_DIM = 128
|
||||
BATCH = 16
|
||||
BLOCK_SIZE = 32
|
||||
SAME_LEN = True
|
||||
|
@ -25,9 +25,9 @@ configs = [
|
|||
x_names=["KV_SEQ_LEN"],
|
||||
x_vals=[2**i for i in range(8, 13)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
|
||||
line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
|
||||
line_vals=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
|
||||
line_names=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
|
||||
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
|
||||
|
@ -45,7 +45,7 @@ def benchmark_kvcache_copy(
|
|||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
dtype = torch.float32
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||
|
@ -63,11 +63,18 @@ def benchmark_kvcache_copy(
|
|||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
||||
if provider == "torch_copy_func":
|
||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||
elif provider == "triton_copy_func":
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||
elif provider == "triton_new_kcache_layout":
|
||||
# NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) to be applied
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
k_cache_shape = (bsz * max_seq_len // block_size, num_kv_heads, HEAD_DIM // x, block_size, x)
|
||||
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) # update k_cache layout
|
||||
fn = lambda: copy_kv_to_blocked_cache(
|
||||
new_k, new_v, k_cache, v_cache, context_lengths, block_tables, use_new_kcache_layout=True
|
||||
)
|
||||
elif provider == "cuda_copy_func":
|
||||
_, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
|
||||
bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
|
||||
|
|
|
@ -10,6 +10,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
|
|||
convert_kv_unpad_to_padded,
|
||||
create_attention_mask,
|
||||
generate_caches_and_block_tables_v2,
|
||||
generate_caches_and_block_tables_v3,
|
||||
torch_attn_ref,
|
||||
)
|
||||
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
|
||||
|
@ -75,6 +76,7 @@ def prepare_data(
|
|||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
@pytest.mark.parametrize("q_len", [1, 5])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
|
||||
def test_flash_decoding(
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
|
@ -84,7 +86,15 @@ def test_flash_decoding(
|
|||
same_context_len: bool,
|
||||
q_len: int,
|
||||
use_alibi_slopes: bool,
|
||||
use_new_kcache_layout: bool,
|
||||
):
|
||||
if use_new_kcache_layout and use_alibi_slopes:
|
||||
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
|
||||
# the code (alibi kernel) will be refactored later to avoid code duplication, when
|
||||
# the whole triton flow with new k cache layout has been supported and tested.
|
||||
# And tests for the alibi kernel using new kcache layout will be added then.
|
||||
pytest.skip("Alibi kernel does not support new kcache layout yet.")
|
||||
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
@ -127,6 +137,11 @@ def test_flash_decoding(
|
|||
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
||||
)
|
||||
|
||||
if use_new_kcache_layout:
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
|
||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
else:
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
|
@ -165,6 +180,7 @@ def test_flash_decoding(
|
|||
sm_scale=sm_scale,
|
||||
kv_group_num=kv_group_num,
|
||||
q_len=q_len,
|
||||
use_new_kcache_layout=use_new_kcache_layout,
|
||||
) # [bsz * q_len, num_heads, head_dim]
|
||||
|
||||
assert out_torch.shape == out_triton.shape
|
||||
|
@ -178,4 +194,4 @@ def test_flash_decoding(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)
|
||||
test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True)
|
||||
|
|
|
@ -4,7 +4,11 @@ from packaging import version
|
|||
|
||||
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import (
|
||||
generate_caches_and_block_tables_v2,
|
||||
generate_caches_and_block_tables_v3,
|
||||
mock_alloc_single_token,
|
||||
)
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
@ -30,6 +34,7 @@ def prepare_data(
|
|||
n=1,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
use_new_kcache_layout=False,
|
||||
):
|
||||
assert max_seq_len > n, "max_seq_len must be greater than n"
|
||||
|
||||
|
@ -44,6 +49,11 @@ def prepare_data(
|
|||
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
if use_new_kcache_layout:
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
|
||||
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
||||
)
|
||||
|
@ -66,8 +76,15 @@ def prepare_data(
|
|||
@pytest.mark.parametrize("num_kv_heads", [16])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
@pytest.mark.parametrize("n_tokens", [1, 5])
|
||||
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
|
||||
def test_copy_kv_to_caches(
|
||||
bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
n_tokens: int,
|
||||
use_new_kcache_layout: bool,
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -89,6 +106,7 @@ def test_copy_kv_to_caches(
|
|||
n_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
use_new_kcache_layout=use_new_kcache_layout,
|
||||
)
|
||||
k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
|
||||
v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
|
||||
|
@ -98,7 +116,9 @@ def test_copy_kv_to_caches(
|
|||
offsets_in_block = past_kv_seq_lengths % block_size
|
||||
|
||||
# Copy k (or v) to k (or v) cache
|
||||
copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
|
||||
copy_k_to_blocked_cache(
|
||||
new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
|
||||
)
|
||||
# Reshape target k from k cache to compare if matching with original tensor
|
||||
# Mainly to handle cases of n_tokens > 1
|
||||
k_target = []
|
||||
|
@ -110,26 +130,39 @@ def test_copy_kv_to_caches(
|
|||
while tokens_left > 0:
|
||||
tokens_to_fill = min(block_size - offset, tokens_left)
|
||||
curr_block_id = block_table[curr_kv_len // block_size]
|
||||
if use_new_kcache_layout:
|
||||
k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
|
||||
else:
|
||||
k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
|
||||
curr_kv_len += tokens_to_fill
|
||||
tokens_left -= tokens_to_fill
|
||||
offset = 0
|
||||
if use_new_kcache_layout:
|
||||
k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
|
||||
k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
|
||||
else:
|
||||
k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
|
||||
|
||||
assert k_target.shape == k_source.shape
|
||||
assert torch.equal(k_target, k_source)
|
||||
|
||||
if n_tokens == 1:
|
||||
# Copy k and v to k/v caches
|
||||
k_cache = k_cache_copy
|
||||
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||
k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||
copy_kv_to_blocked_cache(
|
||||
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
|
||||
)
|
||||
|
||||
if use_new_kcache_layout:
|
||||
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
|
||||
k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
|
||||
else:
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
assert k_target.shape == k_source.shape
|
||||
assert torch.equal(k_target, k_source)
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||
assert v_target.shape == v_source.shape
|
||||
assert torch.equal(v_target, v_source)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1)
|
||||
|
|
|
@ -4,7 +4,10 @@ from packaging import version
|
|||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
from colossalai.kernel.triton import decoding_fused_rotary_embedding
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import (
|
||||
mock_alloc_block_table_and_kvcache_v2,
|
||||
mock_alloc_block_table_and_kvcache_v3,
|
||||
)
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
@ -36,7 +39,8 @@ def torch_rotary_emb(x, cos, sin):
|
|||
@pytest.mark.parametrize("H", [32])
|
||||
@pytest.mark.parametrize("D", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
|
||||
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
|
||||
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||
# our crafted op equals to Transformers
|
||||
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
|
||||
|
@ -57,28 +61,40 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||
k_shape = (TOTAL_TOKENS, H, D)
|
||||
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
v = torch.randn_like(k)
|
||||
v_cache = torch.zeros_like(k_cache)
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
|
||||
new_q = torch.randn_like(new_k)
|
||||
new_v = torch.randn_like(new_k)
|
||||
|
||||
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
|
||||
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
|
||||
|
||||
if use_new_kcache_layout:
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x)
|
||||
k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v3(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
else:
|
||||
k_cache = torch.zeros_like(v_cache)
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
block_tables = block_tables.to(device="cuda")
|
||||
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||
|
||||
decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths)
|
||||
decoding_fused_rotary_embedding(
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout
|
||||
)
|
||||
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
||||
test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)
|
||||
|
|
Loading…
Reference in New Issue