mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix boundary check in batch (#5306)
parent c647e00e3c
commit af8359c430
@@ -22,6 +22,7 @@ def _fwd_context_paged_attention_kernel(
     KCache,
     VCache,
     BLOCK_TABLES,  # [num_seqs, max_blocks_per_sequence]
+    batch_size,
     stride_qt,
     stride_qh,
     stride_qd,
@@ -49,6 +50,8 @@ def _fwd_context_paged_attention_kernel(
     BLOCK_N: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
     block_start_m = tl.program_id(2)  # Br, max_input_len // Block_M
     cur_kv_head_idx = cur_head_idx // KV_GROUPS
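The two added lines are the core of the fix for the context-attention kernel: because the launch grid's batch axis is padded (see the `triton.next_power_of_2` call in the next hunk), some program ids do not correspond to a real sequence and must exit before touching memory. A minimal, runnable sketch of the same guard pattern, using hypothetical kernel and tensor names rather than the actual ColossalAI kernel:

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_rows_kernel(x_ptr, out_ptr, batch_size, n_cols, stride_row, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    # Same guard as the patch: the grid is padded to a power of two,
    # so program ids beyond the real batch size must return early.
    if row >= batch_size:
        return
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    x = tl.load(x_ptr + row * stride_row + offs, mask=mask, other=0.0)
    tl.store(out_ptr + row * stride_row + offs, x * 2.0, mask=mask)


def scale_rows(x: torch.Tensor) -> torch.Tensor:
    bsz, n_cols = x.shape
    out = torch.empty_like(x)
    # Pad the batch axis of the grid exactly as the diff does for num_seqs.
    grid = (triton.next_power_of_2(bsz),)
    _scale_rows_kernel[grid](x, out, bsz, n_cols, x.stride(0), BLOCK=triton.next_power_of_2(n_cols))
    return out

Without the guard, a padded program would compute offsets from an out-of-range row index and read or write past the tensor.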
@@ -217,6 +220,8 @@ def context_attention_unpadded(
     assert block_size in {16, 32, 64, 128}
     BLOCK_M = BLOCK_N = block_size

+    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
+    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))

     _fwd_context_paged_attention_kernel[grid](
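The NOTE attributes the power-of-two padding to Triton's cache mechanism; whatever the exact caching detail, the observable effect of `triton.next_power_of_2` is that many nearby batch sizes collapse onto the same grid size, so the launch configuration changes far less often than the batch size does. A small illustration of the mapping (plain Python, no GPU needed):

import triton

# Every num_seqs in (2**k, 2**(k+1)] maps to the same padded grid axis,
# e.g. 17..32 all launch 32 programs; the in-kernel guard retires the extras.
for num_seqs in (1, 3, 5, 8, 13, 21, 34):
    print(num_seqs, "->", triton.next_power_of_2(num_seqs))
# prints: 1 -> 1, 3 -> 4, 5 -> 8, 8 -> 8, 13 -> 16, 21 -> 32, 34 -> 64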
@@ -227,6 +232,7 @@ def context_attention_unpadded(
         k_cache,
         v_cache,
         block_tables,
+        num_seqs,
         q.stride(0),
         q.stride(1),
         q.stride(2),
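One detail worth noting: Triton kernel arguments are positional, so the wrapper passes the real `num_seqs` in exactly the slot the new `batch_size` parameter occupies (right after the block tables). The kernel therefore compares `cur_seq_idx` against the true sequence count, not against the padded grid size.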
@@ -16,6 +16,7 @@ def _flash_decoding_fwd_kernel(
     mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
+    batch_size,
     stride_qt,
     stride_qh,
     stride_qd,
@@ -39,6 +40,8 @@ def _flash_decoding_fwd_kernel(
     HEAD_DIM: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2)  # for splitting k/v
@@ -132,6 +135,7 @@ def _flash_decoding_fwd_reduce_kernel(
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
     O,  # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
     kv_seq_len,
+    batch_size,
     stride_mid_ot,
     stride_mid_oh,
     stride_mid_ob,
@@ -147,6 +151,8 @@ def _flash_decoding_fwd_reduce_kernel(
     HEAD_DIM: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)

     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
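As in the other two kernels, the guard sits before the first `tl.load`: with a padded grid, an out-of-range `cur_seq_idx` would otherwise be used to index per-sequence metadata such as `kv_seq_len` immediately below, which is exactly the out-of-bounds access this hotfix removes.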
@@ -251,6 +257,8 @@ def flash_decoding_attention(
         else mid_output_lse
     )

+    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
+    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
     _flash_decoding_fwd_kernel[grid](
         q,
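For the decoding forward kernel both padded axes matter: axis 0 is the padded batch and axis 2 is the number of K/V splits, itself computed from a padded `max_seq_len_in_batch`. A quick numeric sketch of the grid this line produces (hypothetical sizes, assuming `BLOCK_KV = 32`):

import triton

bsz, num_heads, max_seq_len_in_batch, BLOCK_KV = 5, 8, 100, 32
grid = (
    triton.next_power_of_2(bsz),                                          # 8 programs on the batch axis
    num_heads,                                                            # 8
    triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),  # ceil(128 / 32) = 4
)
print(grid)  # (8, 8, 4); the 3 extra batch-axis programs hit the new guard and return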
@@ -260,6 +268,7 @@ def flash_decoding_attention(
         mid_output,
         mid_output_lse,
         kv_seq_len,
+        bsz,
         q.stride(0),
         q.stride(1),
         q.stride(2),
@@ -285,12 +294,14 @@ def flash_decoding_attention(

     output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped

-    grid = (bsz, num_heads)
+    grid = (triton.next_power_of_2(bsz), num_heads)

     _flash_decoding_fwd_reduce_kernel[grid](
         mid_output,
         mid_output_lse,
         output,
         kv_seq_len,
+        bsz,
         mid_output.stride(0),
         mid_output.stride(1),
         mid_output.stride(2),
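The reduce kernel's grid is the piece whose behaviour visibly changes here: it used to launch exactly `bsz` programs per head and now launches `triton.next_power_of_2(bsz)`, which is why the same kernel gained the `batch_size` guard above. With the same hypothetical `bsz = 5`:

import triton

bsz, num_heads = 5, 8
old_grid = (bsz, num_heads)                           # (5, 8): every program maps to a real sequence
new_grid = (triton.next_power_of_2(bsz), num_heads)   # (8, 8): 3 programs per head exit at the guard
print(old_grid, new_grid)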