[hotfix] fix boundary check in batch (#5306)

Yuanheng Zhao 2024-01-25 10:23:12 +08:00 committed by GitHub
parent c647e00e3c
commit af8359c430
2 changed files with 18 additions and 1 deletion
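
The fix: the launch grids below are padded to a power of two along the batch axis via `triton.next_power_of_2`, so more programs may be launched than there are sequences in the batch. Each kernel now also receives the actual batch size and returns early when its `tl.program_id(0)` falls outside it; without this boundary check the padding programs would index past the end of the batch. A minimal sketch of the pattern follows the first file's diff.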


@@ -22,6 +22,7 @@ def _fwd_context_paged_attention_kernel(
KCache,
VCache,
BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence]
batch_size,
stride_qt,
stride_qh,
stride_qd,
@@ -49,6 +50,8 @@ def _fwd_context_paged_attention_kernel(
BLOCK_N: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
block_start_m = tl.program_id(2) # Br, max_input_len // Block_M
cur_kv_head_idx = cur_head_idx // KV_GROUPS
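
(The early return above is the hotfix itself: with the padded grid added below, programs whose `cur_seq_idx` is at or beyond the real batch size must exit before issuing any loads or stores.)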
@@ -217,6 +220,8 @@ def context_attention_unpadded(
assert block_size in {16, 32, 64, 128}
BLOCK_M = BLOCK_N = block_size
# NOTE: use `triton.next_power_of_2` here to make use of Triton's kernel-caching mechanism
# To optimize further, revise the batching/scheduling so that each batch holds 2^n sequences (preferred)
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
_fwd_context_paged_attention_kernel[grid](
@@ -227,6 +232,7 @@ def context_attention_unpadded(
k_cache,
v_cache,
block_tables,
num_seqs,
q.stride(0),
q.stride(1),
q.stride(2),
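
To see the pattern in isolation: a minimal, self-contained sketch, not taken from this commit — the names `_double_rows_kernel` and `double_rows` are made up for illustration, and only standard `torch`/`triton` APIs are assumed:

import torch
import triton
import triton.language as tl

@triton.jit
def _double_rows_kernel(X, Y, n_rows, n_cols, stride_x, stride_y, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    # Guard against padding programs from the power-of-two grid: exit
    # before any memory access, exactly as the kernels in this commit do.
    if row >= n_rows:
        return
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    x = tl.load(X + row * stride_x + offs, mask=mask)
    tl.store(Y + row * stride_y + offs, x * 2.0, mask=mask)

def double_rows(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # Padded grid: e.g. 5 rows launch 8 programs; the guard absorbs the extra 3.
    grid = (triton.next_power_of_2(n_rows),)
    _double_rows_kernel[grid](
        x, y, n_rows, n_cols, x.stride(0), y.stride(0),
        BLOCK=triton.next_power_of_2(n_cols),
    )
    return y

# Usage: y = double_rows(torch.randn(5, 100, device="cuda"))  # each row doubled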


@@ -16,6 +16,7 @@ def _flash_decoding_fwd_kernel(
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
kv_seq_len, # [batch_size]
batch_size,
stride_qt,
stride_qh,
stride_qd,
@@ -39,6 +40,8 @@ def _flash_decoding_fwd_kernel(
HEAD_DIM: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v
@@ -132,6 +135,7 @@ def _flash_decoding_fwd_reduce_kernel(
mid_o_lse, # [batch_size, head_num, kv_split_num]
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
kv_seq_len,
batch_size,
stride_mid_ot,
stride_mid_oh,
stride_mid_ob,
@@ -147,6 +151,8 @@ def _flash_decoding_fwd_reduce_kernel(
HEAD_DIM: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
@@ -251,6 +257,8 @@ def flash_decoding_attention(
else mid_output_lse
)
# NOTE: use `triton.next_power_of_2` here to make use of Triton's kernel-caching mechanism
# To optimize further, revise the batching/scheduling so that each batch holds 2^n sequences (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
_flash_decoding_fwd_kernel[grid](
q,
@@ -260,6 +268,7 @@ def flash_decoding_attention(
mid_output,
mid_output_lse,
kv_seq_len,
bsz,
q.stride(0),
q.stride(1),
q.stride(2),
@@ -285,12 +294,14 @@ def flash_decoding_attention(
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
grid = (bsz, num_heads)
grid = (triton.next_power_of_2(bsz), num_heads)
_flash_decoding_fwd_reduce_kernel[grid](
mid_output,
mid_output_lse,
output,
kv_seq_len,
bsz,
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
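
For reference, `triton.next_power_of_2(n)` returns the smallest power of two greater than or equal to `n`, so the padding the new guards absorb is bounded by the batch size itself. A quick check, assuming only the `triton` package:

import triton

for bsz in (1, 5, 8, 13):
    print(bsz, "->", triton.next_power_of_2(bsz))  # 1 -> 1, 5 -> 8, 8 -> 8, 13 -> 16
# At most bsz - 1 extra programs are launched along the batch axis; each one
# hits the `cur_seq_idx >= batch_size` check in the kernels above and returns
# immediately.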