From 5be590b99eb6c58c3aa809d453680139fdd2b9f7 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:51:49 +0800 Subject: [PATCH] [kernel] Support new KCache Layout - Context Attention Triton Kernel (#5658) * add context attn triton kernel - new kcache layout * add benchmark triton * tiny revise * trivial - code style, comment --- .../kernel/triton/context_attn_unpad.py | 243 +++++++++++++++++- .../benchmark_context_attn_unpad.py | 28 +- .../triton/test_context_attn_unpad.py | 33 ++- 3 files changed, 291 insertions(+), 13 deletions(-) diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index a7b5242ff..e2fe6ab92 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -185,6 +185,184 @@ def _fwd_context_paged_attention_kernel( return +# Triton 2.1.0 +# TODO(yuanheng-zhao): This is a temporary dispatch to use the new layout for kcache +# merge `_fwd_context_paged_attention_kernel_v2` with `_fwd_context_paged_attention_kernel` later +# as the kcache layout has been supported in the whole triton flow. +@triton.jit +def _fwd_context_paged_attention_kernel_v2( + Q, + K, + V, + O, + KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_cacheb, # v cache stride(0) - num_blocks + stride_cacheh, # v cache stride(1) - num_kv_heads + stride_cachebs, # v cache stride(2) - block_size + stride_cached, # v cache stride(3) - head_dim + stride_bts, + stride_btb, + context_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, # k stride on the second last dimension + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the current sequence length from provided context lengths tensor + cur_seq_len = tl.load(context_lengths + cur_seq_idx) + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the current sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. 
+ prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + offset_kv, + shape=(HEAD_DIM, cur_seq_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + offset_kv, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + # block table for the current sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the current block table, + # as we have BLOCK_M the same size as the block size. + cur_block_table_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb) + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + if block_start_m * BLOCK_M >= cur_seq_len: + return + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + block_range = tl.arange(0, BLOCK_SIZE) + X_range = tl.arange(0, KCACHE_X) + # unroll the loop aggressively + for split_x in tl.static_range(HEAD_DIM // KCACHE_X): + offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X) + offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt + k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0) + # HACK: KCache must be contiguous in order to apply the following offsets calculation + 
offsets_kcache = ( + KCache + + offset_kvcache + + split_x * BLOCK_SIZE * KCACHE_X + + block_range[:, None] * KCACHE_X + + X_range[None, :] + ) + tl.store(offsets_kcache, k, mask=block_range[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_vd = tl.arange(0, HEAD_DIM) # offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + offsets_n + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0) + offsets_vcache = ( + VCache + offset_kvcache + block_range[None, :] * stride_cachebs + offsets_vd[:, None] * stride_cached + ) + tl.store(offsets_vcache, v, mask=block_range[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + + return + + # Triton 2.1.0 @triton.jit def _alibi_fwd_context_paged_attention_kernel( @@ -375,8 +553,8 @@ def context_attention_unpadded( q: torch.Tensor, # [num_tokens, num_heads, head_dim] k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] v: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] - k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] - v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] + k_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] + v_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], block_size: int, @@ -384,12 +562,24 @@ def context_attention_unpadded( alibi_slopes: torch.Tensor = None, # [num_heads] max_seq_len: int = None, sm_scale: int = None, + # NOTE(yuanheng-zhao): the following flag is used to determine whether to use the new layout for kcache + # [num_blocks, num_kv_heads, head_dim // x, block_size, x] - must be contiguous + use_new_kcache_layout: bool = False, ): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk == Lv assert Lk in {32, 64, 128, 256} assert q.shape[0] == k.shape[0] == v.shape[0] - assert k_cache.shape == v_cache.shape + k_cache_shape = k_cache.shape + v_cache_shape = v_cache.shape + if use_new_kcache_layout: + assert ( + len(k_cache_shape) == 5 + and k_cache_shape[1] == v_cache_shape[1] + and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3] + ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + else: + assert k_cache_shape == v_cache_shape, f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" assert context_lengths.shape[0] == block_tables.shape[0] num_tokens, num_heads, head_dim = q.shape @@ -413,6 +603,53 @@ def context_attention_unpadded( # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M)) + if use_new_kcache_layout: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. 
+ assert ( + alibi_slopes is None + ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready" + x = k_cache_shape[4] # Intuition: 16 // dtype_size + + _fwd_context_paged_attention_kernel_v2[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_SIZE=block_size, + HEAD_DIM=Lk, + KCACHE_X=x, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return output + if alibi_slopes is not None: _alibi_fwd_context_paged_attention_kernel[grid]( q, diff --git a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py index 40b64101c..498282ba3 100644 --- a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py @@ -24,9 +24,9 @@ configs = [ x_vals=[2**i for i in range(8, 13)], # x_vals=[x for x in range(256, 8192, 256)], line_arg="provider", - line_vals=["torch", "triton"], - line_names=["Torch", "Triton"], - styles=[("red", "-"), ("blue", "-")], + line_vals=["torch", "triton", "triton_new_klayout"], + line_names=["Torch", "Triton", "Triton_new_klayout"], + styles=[("red", "-"), ("blue", "-"), ("green", "-")], ylabel="ms", plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}", args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1}, @@ -98,13 +98,33 @@ def bench_kernel( HEAD_DIM, ) ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) - if provider == "triton": + elif provider == "triton": k_cache_triton = torch.zeros_like(k_cache_ref) v_cache_triton = torch.zeros_like(v_cache_ref) fn = lambda: context_attention_unpadded( q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size ) ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + elif provider == "triton_new_klayout": + # NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) + # to be applied around the cuda and triton kernels. + # Here we want to make sure it does not cause downgrade in performance. 
+ x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, HEAD_DIM // x, block_size, x) + k_cache_triton = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) + v_cache_triton = torch.zeros_like(v_cache_ref) + fn = lambda: context_attention_unpadded( + q_unpad, + k_unpad, + v_unpad, + k_cache_triton, + v_cache_triton, + context_lengths, + block_tables, + block_size, + use_new_kcache_layout=True, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) return ms, min_ms, max_ms diff --git a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py index 70f367c09..76785d530 100644 --- a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py @@ -5,7 +5,11 @@ from packaging import version from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref +from tests.test_infer.test_ops.triton.kernel_utils import ( + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + torch_attn_ref, +) try: import triton # noqa @@ -59,7 +63,7 @@ def torch_attn_unpad( mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device) mask[mask == 0.0] = float("-inf") - if slopes != None: + if slopes is not None: alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device) mask = mask + alibi_mask @@ -89,6 +93,7 @@ def torch_attn_unpad( @pytest.mark.parametrize("kv_group_num", [1, 2, 16]) @pytest.mark.parametrize("same_context_len", [True, False]) @pytest.mark.parametrize("use_alibi_slopes", [True, False]) +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) def test_context_attention( bsz: int, block_size: int, @@ -97,7 +102,15 @@ def test_context_attention( kv_group_num: int, same_context_len: bool, use_alibi_slopes: bool, + use_new_kcache_layout: bool, ): + if use_new_kcache_layout and use_alibi_slopes: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. + # And tests for the alibi kernel using new kcache layout will be added then. + return + torch.manual_seed(123) # It's necessary to clear cache here. 
torch.cuda.empty_cache() @@ -124,9 +137,16 @@ def test_context_attention( qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) q_unpad = q_unpad.contiguous() - k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( - k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device - ) + + if use_new_kcache_layout: + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + else: + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) k_cache_triton = torch.zeros_like(k_cache_ref) v_cache_triton = torch.zeros_like(v_cache_ref) @@ -143,6 +163,7 @@ def test_context_attention( block_tables, block_size, alibi_slopes=alibi_slopes, + use_new_kcache_layout=use_new_kcache_layout, ) out_triton = out_triton.view(-1, num_heads, head_dim) @@ -155,4 +176,4 @@ def test_context_attention( if __name__ == "__main__": - test_context_attention(4, 32, 8, 16, 1, True, True) + test_context_attention(4, 32, 8, 16, 1, True, True, True)
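
Illustrative sketch of the new KCache layout (not part of the patch above): the v2 kernel stores K as [num_blocks, num_kv_heads, head_dim // x, block_size, x], with x = 16 // dtype_size (8 for fp16), while V keeps the [num_blocks, num_kv_heads, block_size, head_dim] layout. The plain-PyTorch snippet below is a minimal reconstruction of that mapping, assuming both caches are contiguous; the kernel's flat offset, split_x * BLOCK_SIZE * KCACHE_X + block_range * KCACHE_X + X_range, only addresses the right elements under that same contiguity assumption, which is what the HACK comment in the kernel points at. All sizes and tensor names here are made up for illustration.

import torch

# Hypothetical sizes, chosen only for illustration.
num_blocks, num_kv_heads, block_size, head_dim = 4, 2, 16, 64
dtype = torch.float16
x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 for fp16

# Previous layout: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache_old = torch.randn(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)

# New layout: split head_dim into (head_dim // x, x) and move block_size between the two parts.
k_cache_new = (
    k_cache_old.view(num_blocks, num_kv_heads, block_size, head_dim // x, x)
    .permute(0, 1, 3, 2, 4)
    .contiguous()  # the kernel's offset arithmetic relies on contiguity
)

# Element-wise correspondence: new[b, h, d // x, s, d % x] == old[b, h, s, d]
b, h, s, d = 1, 0, 5, 37
assert k_cache_new[b, h, d // x, s, d % x] == k_cache_old[b, h, s, d]

# Within one (block, head) slice, the contiguous strides of the last three
# dimensions match the kernel's flat offset terms (BLOCK_SIZE * KCACHE_X, KCACHE_X, 1).
assert k_cache_new.stride()[2:] == (block_size * x, x, 1)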
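
For completeness, a minimal caller-side sketch of opting into the new layout, modeled on the benchmark and test changes above. The sizes, random inputs, and the trivial block table are assumptions made for illustration only (the tests build caches and block tables via generate_caches_and_block_tables_v3); the substantive points are that the K cache is allocated in the 5-D layout, the V cache keeps its existing layout, alibi_slopes stays None (the alibi path does not yet support the new layout), and use_new_kcache_layout=True is passed.

import torch
from colossalai.kernel.triton import context_attention_unpadded

device = "cuda"
dtype = torch.float16
bsz, num_heads, num_kv_heads, head_dim = 4, 16, 16, 128
block_size, max_blocks_per_seq = 32, 8
x = 16 // torch.tensor([], dtype=dtype).element_size()

# Unpadded (packed) Q/K/V: [num_tokens, num_(kv_)heads, head_dim]
context_lengths = torch.randint(1, block_size * max_blocks_per_seq, (bsz,), dtype=torch.int32, device=device)
num_tokens = int(context_lengths.sum())
q = torch.randn(num_tokens, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=dtype, device=device)

# K cache in the new 5-D layout; V cache keeps [num_blocks, num_kv_heads, block_size, head_dim].
num_blocks = bsz * max_blocks_per_seq
k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim // x, block_size, x, dtype=dtype, device=device)
v_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype, device=device)

# Trivial block table: sequence i owns blocks [i * max_blocks_per_seq, (i + 1) * max_blocks_per_seq).
block_tables = torch.arange(num_blocks, dtype=torch.int32, device=device).view(bsz, max_blocks_per_seq)

out = context_attention_unpadded(
    q, k, v, k_cache, v_cache, context_lengths, block_tables, block_size,
    use_new_kcache_layout=True,
)
# As in the test above, the result can be viewed as (num_tokens, num_heads, head_dim).
out = out.view(-1, num_heads, head_dim)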