@@ -9,13 +9,14 @@ import triton.language as tl
 # Triton 2.1.0
 @triton.jit
 def _flash_decoding_fwd_kernel(
-    Q,  # [batch_size, head_num, q_len(1), head_dim]
+    Q,  # [batch_size * q_len, head_num, head_dim]
     KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
     VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
     block_tables,  # [batch_size, max_blocks_per_sequence]
-    mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
-    mid_o_lse,  # [batch_size, head_num, kv_split_num]
+    mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
+    mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
+    q_len,
     batch_size,
     stride_qt,
     stride_qh,
@@ -39,44 +40,37 @@ def _flash_decoding_fwd_kernel(
     BLOCK_SIZE: tl.constexpr,
     HEAD_DIM: tl.constexpr,
 ):
-    cur_seq_idx = tl.program_id(0)
+    cur_token_idx = tl.program_id(0)
+    cur_seq_idx = cur_token_idx // q_len
     if cur_seq_idx >= batch_size:
         return
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2)  # for splitting k/v

-    cur_kv_head_idx = cur_head_idx // KV_GROUPS
-    offsets_dmodel = tl.arange(0, HEAD_DIM)
-
-    # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
-    # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
-    # and then support calculating multiple kv cache blocks on an instance
-    tl.static_assert(BLOCK_KV == BLOCK_SIZE)

-    # get the current (kv) sequence length from provided context lengths tensor
+    # get the current (kv) sequence length
     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
+    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
+        return

-    offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
+    offsets_dmodel = tl.arange(0, HEAD_DIM)
+    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
     q = tl.load(Q + offsets_q)

     # block table for the current sequence
     block_table_ptr = block_tables + cur_seq_idx * stride_bts

-    # actually current block table current block start idx
     # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
-    cur_bt_start_idx = block_start_kv
-    cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
-
-    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
-        return
-
+    # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
+    cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
     cur_occupied_size = tl.where(
         (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
     )
     tl.device_assert(cur_occupied_size >= 0)

+    cur_kv_head_idx = cur_head_idx // KV_GROUPS
     offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh

     K_block_ptr = tl.make_block_ptr(
         base=KCache + offset_kvcache,
         shape=(cur_occupied_size, HEAD_DIM),
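For orientation (an illustration, not part of the patch): after this change the first grid axis enumerates token slots rather than sequences, and each program recovers its sequence index by integer division. The names below mirror the kernel's variables; the numbers are arbitrary.

q_len, batch_size = 2, 3  # e.g. verifying 2 draft tokens per sequence
for cur_token_idx in range(batch_size * q_len):
    cur_seq_idx = cur_token_idx // q_len
    print(cur_token_idx, cur_seq_idx)  # (0,0) (1,0) (2,1) (3,1) (4,2) (5,2)

Note that the query offset is built from cur_token_idx, while kv_seq_len and block_tables are still indexed with cur_seq_idx, so each of a sequence's q_len query tokens attends over that sequence's full paged KV cache.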
@@ -115,14 +109,14 @@ def _flash_decoding_fwd_kernel(
     acc = acc / l

     offsets_mid_o = (
-        cur_seq_idx * stride_mid_ot
+        cur_token_idx * stride_mid_ot
         + cur_head_idx * stride_mid_oh
         + block_start_kv * stride_mid_ob
         + offsets_dmodel * stride_mid_od
     )
     tl.store(mid_o + offsets_mid_o, acc)
     offsets_mid_o_lse = (
-        cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
+        cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
     )
     # logsumexp L^(j) = m^(j) + log(l^(j))
     tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
@@ -135,6 +129,7 @@ def _flash_decoding_fwd_reduce_kernel(
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
     O,  # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
     kv_seq_len,
+    q_len,
     batch_size,
     stride_mid_ot,
     stride_mid_oh,
@@ -149,7 +144,8 @@ def _flash_decoding_fwd_reduce_kernel(
     BLOCK_KV: tl.constexpr,
     HEAD_DIM: tl.constexpr,
 ):
-    cur_seq_idx = tl.program_id(0)
+    cur_token_idx = tl.program_id(0)
+    cur_seq_idx = cur_token_idx // q_len
     if cur_seq_idx >= batch_size:
         return
     cur_head_idx = tl.program_id(1)
@@ -164,8 +160,8 @@ def _flash_decoding_fwd_reduce_kernel(
     l = 0.0  # sum exp
     acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

-    offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
-    offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh
+    offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
+    offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
     for block_i in range(0, kv_split_num, 1):
         mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
         lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
@@ -179,7 +175,7 @@ def _flash_decoding_fwd_reduce_kernel(
         m_i = m_ij

     acc = acc / l
-    offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
+    offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
     tl.store(O + offsets_O, acc.to(O.type.element_ty))
     return

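For reference, the merge that the stored per-split values make possible (the standard flash-decoding reduction, written out here for clarity; the notation is mine, not from the patch): with per-split normalized partial outputs $O^{(j)} = \mathrm{acc}^{(j)} / l^{(j)}$ (what the forward kernel stores in mid_o) and log-sum-exps $L^{(j)} = m^{(j)} + \log l^{(j)}$ (stored in mid_o_lse),

$$ L = \log \sum_j e^{L^{(j)}}, \qquad O = \sum_j e^{L^{(j)} - L} \, O^{(j)}, $$

which the reduce kernel above evaluates with a running maximum ($m_i$, $m_{ij}$ in its loop) for numerical stability.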
@@ -199,12 +195,14 @@ def flash_decoding_attention(
     mid_output_lse: torch.Tensor = None,
     sm_scale: int = None,
     kv_group_num: int = 1,
+    q_len: int = 1,
 ):
     """
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

     Args:
-        q (torch.Tensor): [bsz, num_heads, head_dim]
+        q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
+            q_len > 1 only for verification process in speculative-decoding.
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
         kv_seq_len (torch.Tensor): [batch_size]
@@ -212,19 +210,25 @@ def flash_decoding_attention(
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
         output (torch.Tensor): [bsz, num_heads * head_dim]
-        mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
+        mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
-        mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
+            q_len > 1 only for verification process in speculative-decoding.
+        mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
             Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
+            q_len > 1 only for verification process in speculative-decoding.
         block_size (int): Size of each block in the blocked key/value cache.
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
+        q_len (int): Query length. Use for speculative decoding when `q_len` > 1 (i.e. the last n tokens).
+            Defaults to 1.

     Returns:
-        Output tensor with shape [bsz, num_heads * head_dim]
+        Output tensor with shape [bsz * q_len, num_heads * head_dim]
     """
     q = q.squeeze() if q.dim() == 4 else q
     assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
-    bsz, num_heads, head_dim = q.shape
+    n_tokens, num_heads, head_dim = q.shape
+    assert n_tokens % q_len == 0, "Invalid q_len"
+    bsz = n_tokens // q_len

     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
@@ -247,22 +251,31 @@ def flash_decoding_attention(
     max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
     # For compatibility (TODO revise modeling in future)
     kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
-    mid_output = (
-        torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
-        if mid_output is None
-        else mid_output
-    )
-    mid_output_lse = (
-        torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-        if mid_output_lse is None
-        else mid_output_lse
-    )
-
+    if mid_output is None:
+        mid_output = torch.empty(
+            (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
+        )
+    if mid_output_lse is None:
+        mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+    if output is None:
+        # A hack to prevent `view` operation in modeling
+        output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)
+
+    assert (
+        mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
+    ), "Incompatible kv split number of intermediate output tensors"
+    assert (
+        mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
+    ), f"Incompatible first dimension of output tensors"
+
     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
-    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
-    output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output

+    grid = (
+        triton.next_power_of_2(bsz * q_len),
+        num_heads,
+        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
+    )
     _flash_decoding_fwd_kernel[grid](
         q,
         k_cache,
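A worked example of the launch geometry above (illustrative values, not from the patch; BLOCK_KV equals the cache block size, as set earlier in the wrapper outside this excerpt):

import triton

bsz, q_len, num_heads = 3, 2, 8
max_seq_len_in_batch, BLOCK_KV = 1000, 16

kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV  # 63 splits preallocated per token
grid = (
    triton.next_power_of_2(bsz * q_len),                                 # 8 (>= 6 token slots)
    num_heads,                                                           # 8
    triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),  # 1024 // 16 = 64
)
# Surplus programs, i.e. token slots beyond bsz * q_len or kv blocks past a
# sequence's length, exit at the early-return guards in the forward kernel.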
@@ -271,6 +284,7 @@ def flash_decoding_attention(
         mid_output,
         mid_output_lse,
         kv_seq_len,
+        q_len,
         bsz,
         q.stride(0),
         q.stride(1),
@@ -295,13 +309,13 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )

-    grid = (triton.next_power_of_2(bsz), num_heads)
-
+    grid = (triton.next_power_of_2(bsz * q_len), num_heads)
     _flash_decoding_fwd_reduce_kernel[grid](
         mid_output,
         mid_output_lse,
         output,
         kv_seq_len,
+        q_len,
         bsz,
         mid_output.stride(0),
         mid_output.stride(1),
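Finally, a minimal call sketch of the updated wrapper with q_len > 1. This is a hedged illustration, not part of the patch: the dummy shapes, the CUDA device, and the import path are assumptions; the keyword names follow the docstring above.

import torch

from colossalai.kernel.triton import flash_decoding_attention  # import path assumed

bsz, q_len, num_heads, num_kv_heads, head_dim, block_size = 2, 2, 8, 2, 128, 16
max_blocks_per_seq = 4

q = torch.randn(bsz * q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(
    bsz * max_blocks_per_seq, num_kv_heads, block_size, head_dim, dtype=torch.float16, device="cuda"
)
v_cache = torch.randn_like(k_cache)
# sequence i owns physical blocks [i * max_blocks_per_seq, (i + 1) * max_blocks_per_seq)
block_tables = torch.arange(bsz * max_blocks_per_seq, dtype=torch.int32, device="cuda").view(bsz, -1)
kv_seq_len = torch.tensor([40, 25], dtype=torch.int32, device="cuda")  # already counts the q_len new tokens

out = flash_decoding_attention(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    kv_seq_len=kv_seq_len,
    block_tables=block_tables,
    block_size=block_size,
    max_seq_len_in_batch=int(kv_seq_len.max()),
    sm_scale=head_dim**-0.5,
    kv_group_num=num_heads // num_kv_heads,
    q_len=q_len,
)  # [bsz * q_len, num_heads * head_dim]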