diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index e2fe6ab92..9c69c4125 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -338,8 +338,8 @@ def _fwd_context_paged_attention_kernel_v2( X_range = tl.arange(0, KCACHE_X) # unroll the loop aggressively for split_x in tl.static_range(HEAD_DIM // KCACHE_X): - offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X) - offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt + offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X) + offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0) # HACK: KCache must be contiguous in order to apply the following offsets calculation offsets_kcache = ( diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 200835ec3..2fb8231cc 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -11,20 +11,29 @@ import triton.language as tl def _flash_decoding_fwd_kernel( Q, # [batch_size * q_len, head_num, head_dim] KCache, # [num_blocks, num_kv_heads, block_size, head_dim] - VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim], + # or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided block_tables, # [batch_size, max_blocks_per_sequence] mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim] mid_o_lse, # [batch_size * q_len, head_num, kv_split_num] kv_seq_len, # [batch_size] q_len, batch_size, + kv_group_num, + x, + sm_scale, stride_qt, stride_qh, stride_qd, - stride_cacheb, - stride_cacheh, - stride_cachebs, - stride_cached, + stride_kcb, + stride_kch, + stride_kcsplit_x, + stride_kcs, + stride_kcd, + stride_vcb, + stride_vch, + stride_vcs, + stride_vcd, stride_bts, stride_btb, stride_mid_ot, @@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel( stride_mid_o_lset, stride_mid_o_lseh, stride_mid_o_lseb, - sm_scale, - KV_GROUPS: tl.constexpr, BLOCK_KV: tl.constexpr, BLOCK_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr, @@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel( cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off if block_start_kv * BLOCK_KV >= cur_kv_seq_len: return - offsets_dmodel = tl.arange(0, HEAD_DIM) - offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd - q = tl.load(Q + offsets_q) + offsets_block = tl.arange(0, BLOCK_SIZE) + # block table for the current sequence block_table_ptr = block_tables + cur_seq_idx * stride_bts # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE) @@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel( ) tl.device_assert(cur_occupied_size >= 0) - cur_kv_head_idx = cur_head_idx // KV_GROUPS - offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh - K_block_ptr = tl.make_block_ptr( - base=KCache + offset_kvcache, - shape=(cur_occupied_size, HEAD_DIM), - strides=(stride_cachebs, stride_cached), - offsets=(0, 0), - block_shape=(BLOCK_SIZE, HEAD_DIM), - order=(0, 1), + offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd + q = tl.load(Q + offsets_q) + cur_kv_head_idx = cur_head_idx // kv_group_num + 
offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch + offsets_k = ( + offset_kvcache + + (offsets_dmodel[None, :] // x) * stride_kcsplit_x + + (offsets_dmodel[None, :] % x) * stride_kcd + + offsets_block[:, None] * stride_kcs ) + k_cur_block = tl.load(KCache + offsets_k) V_block_ptr = tl.make_block_ptr( base=VCache + offset_kvcache, shape=(cur_occupied_size, HEAD_DIM), - strides=(stride_cachebs, stride_cached), + strides=(stride_vcs, stride_vcd), offsets=(0, 0), block_shape=(BLOCK_SIZE, HEAD_DIM), order=(0, 1), ) - k_cur_block = tl.load(K_block_ptr) v_cur_block = tl.load(V_block_ptr) acc = tl.zeros([HEAD_DIM], dtype=tl.float32) # use block size of the paged/blocked kv cache @@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel( # Refer to https://github.com/openai/triton/discussions/895 S_ij += tl.sum(q[None, :] * k_cur_block, 1) S_ij *= sm_scale - S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf")) + S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf")) m = tl.max(S_ij, 0) S_ij -= m @@ -324,6 +330,7 @@ def flash_decoding_attention( sm_scale: int = None, kv_group_num: int = 1, q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment. + use_new_kcache_layout: bool = False, ): """ Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. @@ -349,6 +356,7 @@ def flash_decoding_attention( num_kv_group (int, optional): Number of key/value groups. Defaults to 1. q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens). Defaults to 1. + use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False. Returns: Output tensor with shape [bsz * q_len, num_heads * head_dim] @@ -400,13 +408,20 @@ def flash_decoding_attention( # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) - grid = ( + grid = lambda META: ( triton.next_power_of_2(bsz * q_len), num_heads, - triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV), + triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]), ) if alibi_slopes is not None: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. 
+ assert ( + not use_new_kcache_layout + ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready" + _alibi_flash_decoding_fwd_kernel[grid]( q, k_cache, @@ -441,6 +456,19 @@ def flash_decoding_attention( HEAD_DIM=head_dim, ) else: + # For KCache and VCache with the same layout + x = head_dim + kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3) + # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x] + if use_new_kcache_layout: + assert ( + k_cache.dim() == 5 + and k_cache.shape[1] == v_cache.shape[1] + and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3] + ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}" + x = k_cache.size(-1) + kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:] + _flash_decoding_fwd_kernel[grid]( q, k_cache, @@ -451,13 +479,21 @@ def flash_decoding_attention( kv_seq_len, q_len, bsz, + kv_group_num, + x, + sm_scale, q.stride(0), q.stride(1), q.stride(2), k_cache.stride(0), k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), + kcsplit_x_stride, + kcs_stride, + kcd_stride, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), block_tables.stride(0), block_tables.stride(1), mid_output.stride(0), @@ -467,8 +503,6 @@ def flash_decoding_attention( mid_output_lse.stride(0), mid_output_lse.stride(1), mid_output_lse.stride(2), - sm_scale, - KV_GROUPS=kv_group_num, BLOCK_KV=block_size, BLOCK_SIZE=block_size, HEAD_DIM=head_dim, diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py index 871f1f6d8..77397b5cb 100644 --- a/colossalai/kernel/triton/kvcache_copy.py +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -4,56 +4,69 @@ import triton.language as tl # Triton 2.1.0 +# supports two types of cache layouts +# 1. [num_blocks, num_kv_heads, block_size, head_dim] +# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x] @triton.jit def _copy_to_kcache_seqlen_n_kernel( - KV, # K or V - KVCache, # KCache or VCache + K, # K or V + KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] BLOCK_TABLES, - context_lengths, + seq_lengths, stride_kt, stride_kh, stride_kd, - stride_cacheb, - stride_cacheh, - stride_cachebs, - stride_cached, + stride_kcb, + stride_kch, + stride_kcsplit_x, + stride_kcs, + stride_kcx, stride_bts, stride_btb, block_size, - n, + n_tokens, HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, ): + # `n_tokens` is used to specify the number of tokens to copy for each sequence + # When n_tokens > 1, tokens from different sequences are packed into the first dimension of the grid, + # `seq_lengths` must be the lengths of sequences counting the number of tokens to copy + # E.g. if n_tokens = 5, seq_lengths = [12, 15], then the already-copied position ids are [0-6, 0-9] + # for the two sequences, respectively. And the position ids to be copied are [7-11, 9-14]. 
+ # When n_tokens = 1, consider token idx as the sequence idx, since it's only used during regular decoding stage cur_token_idx = tl.program_id(0) - cur_seq_idx = cur_token_idx // n - cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1)) - # cur_token_shift = cur_token_idx - n * cur_seq_idx + cur_seq_idx = cur_token_idx // n_tokens + # `cur_token_shift` is only valid and functional when `n_tokens` > 1 + cur_token_shift = cur_token_idx - (n_tokens * (cur_seq_idx + 1)) cur_kv_head_idx = tl.program_id(1) + split_x_idx = tl.program_id(2) - past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift + past_kv_seq_len = tl.load(seq_lengths + cur_seq_idx) + cur_token_shift last_bt_block_idx = past_kv_seq_len // block_size block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) offset_last_block = past_kv_seq_len % block_size - offsets_dmodel = tl.arange(0, HEAD_DIM) - offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd - kv = tl.load(KV + offsets_kv) - offsets_kvcache = ( - block_id * stride_cacheb - + cur_kv_head_idx * stride_cacheh - + offset_last_block * stride_cachebs - + offsets_dmodel * stride_cached + offsets_dmodel = split_x_idx * KCACHE_X + tl.arange(0, KCACHE_X) + offsets_k = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd + k = tl.load(K + offsets_k) + offsets_kcache = ( + block_id * stride_kcb + + cur_kv_head_idx * stride_kch + + split_x_idx * stride_kcsplit_x + + offset_last_block * stride_kcs + + tl.arange(0, KCACHE_X) ) - tl.store(KVCache + offsets_kvcache, kv) + tl.store(KCache + offsets_kcache, k) return # Triton 2.1.0 @triton.jit def _copy_to_kvcache_seqlen1_kernel( - K, # K - V, # V - KCache, # KCache - VCache, # VCache + K, + V, + KCache, + VCache, BLOCK_TABLES, context_lengths, stride_kt, @@ -62,18 +75,20 @@ def _copy_to_kvcache_seqlen1_kernel( stride_vt, stride_vh, stride_vd, - stride_cachekb, - stride_cachekh, - stride_cachekbs, - stride_cachekd, - stride_cachevb, - stride_cachevh, - stride_cachevbs, - stride_cachevd, + stride_kcb, + stride_kch, + stride_kcsplit_x, + stride_kcs, + stride_kcd, + stride_vcb, + stride_vch, + stride_vcs, + stride_vcd, stride_bts, stride_btb, block_size, HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, ): cur_seq_idx = tl.program_id(0) cur_kv_head_idx = tl.program_id(1) @@ -83,33 +98,42 @@ def _copy_to_kvcache_seqlen1_kernel( block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) offsets_in_last_block = past_kv_seq_len % block_size - offsets_dmodel = tl.arange(0, HEAD_DIM) - offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd - offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd - k = tl.load(K + offsets_k) - v = tl.load(V + offsets_v) + range_x = tl.arange(0, KCACHE_X) + offsets_dmodel_x_partition = tl.arange(0, KCACHE_X) - offsets_kcache = ( - block_id * stride_cachekb - + cur_kv_head_idx * stride_cachekh - + offsets_in_last_block * stride_cachekbs - + offsets_dmodel * stride_cachekd - ) - offsets_vcache = ( - block_id * stride_cachevb - + cur_kv_head_idx * stride_cachevh - + offsets_in_last_block * stride_cachevbs - + offsets_dmodel * stride_cachevd - ) + for split_x in tl.static_range(HEAD_DIM // KCACHE_X): + offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X) + offsets_k = 
cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel_x_partition * stride_kd + k = tl.load(K + offsets_k) + offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel_x_partition * stride_vd + v = tl.load(V + offsets_v) - tl.store(KCache + offsets_kcache, k) - tl.store(VCache + offsets_vcache, v) + offsets_kcache = ( + block_id * stride_kcb + + cur_kv_head_idx * stride_kch + + split_x * stride_kcsplit_x + + offsets_in_last_block * stride_kcs + + range_x + ) + tl.store(KCache + offsets_kcache, k) + offsets_vcache = ( + block_id * stride_vcb + + cur_kv_head_idx * stride_vch + + offsets_in_last_block * stride_vcs + + offsets_dmodel_x_partition * stride_vcd + ) + tl.store(VCache + offsets_vcache, v) return def copy_k_to_blocked_cache( - k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1 + k: torch.Tensor, + k_cache: torch.Tensor, + kv_lengths: torch.Tensor, + block_tables: torch.Tensor, + n: int = 1, + use_new_kcache_layout: bool = False, ): """ Copy keys or values to the blocked key/value cache during decoding stage. @@ -118,16 +142,17 @@ def copy_k_to_blocked_cache( k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache. + new KCache Layout [num_blocks, num_kv_heads, head_dim // x, block_size, x] kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. n (int): Number of tokens to copy for each sequence. Default to 1. + use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False. """ - assert k.size(-1) == k_cache.size(-1), "Incompatible head dim" assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." - - k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k - assert k.dim() == 3, f"Invalid k dim {k.dim()}" - bsz, num_kv_heads, head_dim = k.shape + if k.dim() == 4: + k = k.reshape(-1, k.size(-2), k.size(-1)) + k_shape = k.shape + bsz, num_kv_heads, head_dim = k_shape # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim] if n > 1: assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied" @@ -139,12 +164,24 @@ def copy_k_to_blocked_cache( f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}" ) + k_cache_shape = k_cache.shape # Modify if the shape of kv cahce is changed. 
- block_size = k_cache.size(-2) + block_size = k_cache_shape[-2] + + x = head_dim + stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3) + if use_new_kcache_layout: + # when using kcache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x] + assert ( + len(k_cache_shape) == 5 + and k_cache_shape[1] == k_shape[1] + and k_cache_shape[2] * k_cache_shape[4] == k_shape[2] + ), f"Incompatible k_cache shape {k_cache_shape} with k shape {k_shape}" + x = k_cache.size(-1) + stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:] num_warps = 8 if head_dim > 128 else 4 - - grid = (bsz * n, num_kv_heads) + grid = (bsz * n, num_kv_heads, head_dim // x) _copy_to_kcache_seqlen_n_kernel[grid]( k, k_cache, @@ -155,13 +192,15 @@ def copy_k_to_blocked_cache( k.stride(2), k_cache.stride(0), k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), + stride_kcsplit_x, + stride_kcs, + stride_kcd, block_tables.stride(0), block_tables.stride(1), block_size, - n=n, + n_tokens=n, HEAD_DIM=head_dim, + KCACHE_X=x, num_warps=num_warps, ) @@ -173,6 +212,7 @@ def copy_kv_to_blocked_cache( v_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, + use_new_kcache_layout: bool = False, ): """ Copy keys or values to the blocked key/value cache during decoding stage. @@ -184,19 +224,30 @@ def copy_kv_to_blocked_cache( v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache. kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. + use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False. """ - assert k.size(-1) == k_cache.size(-1), "Incompatible head dim" - assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." + k_cache_shape = k_cache.shape + v_cache_shape = v_cache.shape + + if use_new_kcache_layout: + assert ( + len(k_cache_shape) == 5 + and k_cache_shape[1] == v_cache_shape[1] + and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3] + ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + else: + assert k.size(-1) == k_cache_shape[-1], "Incompatible head dim" + assert ( + k_cache_shape == v_cache_shape + ), f"Incompatible KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + assert v.size(-1) == v_cache_shape[-1], "Incompatible head dim" + k = k.squeeze(1) if k.dim() == 4 else k assert k.dim() == 3, f"Incompatible k dim {k.dim()}" - - assert v.size(-1) == v_cache.size(-1), "Incompatible head dim" - assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache." v = v.squeeze(1) if v.dim() == 4 else v assert v.dim() == 3, f"Incompatible v dim {v.dim()}" bsz, num_kv_heads, head_dim = k.shape - assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( f"Got incompatible batch size (number of seqs):\n" f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; " @@ -206,6 +257,12 @@ def copy_kv_to_blocked_cache( # Modify if the shape of kv cahce is changed. 
block_size = k_cache.size(-2) + x = head_dim + stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3) + if use_new_kcache_layout: + x = k_cache.size(-1) + stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:] + num_warps = 8 if head_dim > 128 else 4 grid = (bsz, num_kv_heads) _copy_to_kvcache_seqlen1_kernel[grid]( @@ -223,8 +280,9 @@ def copy_kv_to_blocked_cache( v.stride(2), k_cache.stride(0), k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), + stride_kcsplit_x, + stride_kcs, + stride_kcd, v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), @@ -233,5 +291,6 @@ def copy_kv_to_blocked_cache( block_tables.stride(1), block_size, HEAD_DIM=head_dim, + KCACHE_X=x, num_warps=num_warps, ) diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index ad3946353..e0da816bd 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional import torch @@ -85,8 +86,8 @@ def rotary_embedding_kernel( mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), ) - handle_k = cur_head_idx % KV_GROUP_NUM == 0 - if handle_k: + handle_kv = cur_head_idx % KV_GROUP_NUM == 0 + if handle_kv: k_head_idx = cur_head_idx // KV_GROUP_NUM off_k0 = ( tokens_range[:, None, None] * k_token_stride @@ -385,6 +386,7 @@ def decoding_fused_rotary_embedding_kernel( v_cache, BLOCK_TABLES, context_lengths, + x, q_token_stride, q_head_stride, k_token_stride, @@ -392,10 +394,15 @@ def decoding_fused_rotary_embedding_kernel( head_dim_stride, cos_token_stride, cos_stride, - cache_b_stride, - cache_h_stride, - cache_bs_stride, - cache_d_stride, + kcb_stride, + kch_stride, + kcsplit_x_stride, + kcs_stride, + kcd_stride, + vcb_stride, + vch_stride, + vcs_stride, + vcd_stride, bts_stride, btb_stride, block_size, @@ -424,8 +431,8 @@ def decoding_fused_rotary_embedding_kernel( tl.store(q + off_q0, out_q0) tl.store(q + off_q1, out_q1) - handle_k = cur_head_idx % KV_GROUP_NUM == 0 - if handle_k: + handle_kv = cur_head_idx % KV_GROUP_NUM == 0 + if handle_kv: cur_k_head_idx = cur_head_idx // KV_GROUP_NUM off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride off_k0 = off_kv + dim_range0 * head_dim_stride @@ -443,17 +450,18 @@ def decoding_fused_rotary_embedding_kernel( last_block_idx = past_kv_seq_len // block_size block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride) offsets_in_last_block = past_kv_seq_len % block_size + offsets_cache_base = block_ids * kcb_stride + cur_k_head_idx * kch_stride k_range0 = ( - block_ids * cache_b_stride - + cur_k_head_idx * cache_h_stride - + offsets_in_last_block * cache_bs_stride - + dim_range0 * cache_d_stride + offsets_cache_base + + offsets_in_last_block * kcs_stride + + (dim_range0 // x) * kcsplit_x_stride + + (dim_range0 % x) * kcd_stride ) k_range1 = ( - block_ids * cache_b_stride - + cur_k_head_idx * cache_h_stride - + offsets_in_last_block * cache_bs_stride - + dim_range1 * cache_d_stride + offsets_cache_base + + offsets_in_last_block * kcs_stride + + (dim_range1 // x) * kcsplit_x_stride + + (dim_range1 % x) * kcd_stride ) tl.store(k_cache + k_range0, out_k0) tl.store(k_cache + k_range1, out_k1) @@ -461,10 +469,10 @@ def decoding_fused_rotary_embedding_kernel( off_v = off_kv + dim_range * head_dim_stride loaded_v = tl.load(v + off_v) v_range = ( - block_ids * cache_b_stride - + 
cur_k_head_idx * cache_h_stride - + offsets_in_last_block * cache_bs_stride - + dim_range * cache_d_stride + block_ids * vcb_stride + + cur_k_head_idx * vch_stride + + offsets_in_last_block * vcs_stride + + dim_range * vcd_stride ) tl.store(v_cache + v_range, loaded_v) @@ -532,6 +540,7 @@ def rotary_embedding( num_warps=num_warps, ) else: + warnings.warn("Fused rotary embedding Triton kernel will be deprecated as the new kcache layout is supported") grid = (triton.next_power_of_2(q_head_num), q_total_tokens) fused_rotary_embedding_kernel_v2[grid]( q, @@ -573,6 +582,7 @@ def decoding_fused_rotary_embedding( v_cache: Optional[torch.Tensor] = None, block_tables: Optional[torch.Tensor] = None, kv_lengths: Optional[torch.Tensor] = None, + use_new_kcache_layout: bool = False, ): """ Args: @@ -588,8 +598,6 @@ def decoding_fused_rotary_embedding( """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) == v.size(0) - assert k.size(1) == v.size(1) - assert k_cache.size(-1) == v_cache.size(-1) if head_dim >= 512: num_warps = 16 @@ -597,18 +605,22 @@ def decoding_fused_rotary_embedding( num_warps = 8 else: num_warps = 4 - - q_token_stride = q.stride(0) - q_head_stride = q.stride(1) - head_dim_stride = q.stride(2) - - k_token_stride = k.stride(0) - k_head_stride = k.stride(1) k_head_num = k.size(1) kv_group_num = q_head_num // k_head_num - cos_token_stride = cos.stride(0) - cos_stride = cos.stride(1) + # For KCache and VCache with the same layout + x = head_dim + kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3) + # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x] + if use_new_kcache_layout: + assert ( + k_cache.dim() == 5 + and k_cache.shape[1] == v_cache.shape[1] + and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3] + ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}" + x = k_cache.size(-1) + kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:] + grid = (q_head_num, q_total_tokens) decoding_fused_rotary_embedding_kernel[grid]( q, @@ -620,17 +632,23 @@ def decoding_fused_rotary_embedding( v_cache, block_tables, kv_lengths, - q_token_stride, - q_head_stride, - k_token_stride, - k_head_stride, - head_dim_stride, - cos_token_stride, - cos_stride, + x, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), k_cache.stride(0), k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), + kcsplit_x_stride, + kcs_stride, + kcd_stride, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), block_tables.stride(0), block_tables.stride(1), k_cache.size(-2), diff --git a/examples/inference/benchmark_ops/benchmark_decoding_attn.py b/examples/inference/benchmark_ops/benchmark_decoding_attn.py index ae104c807..1a80961a7 100644 --- a/examples/inference/benchmark_ops/benchmark_decoding_attn.py +++ b/examples/inference/benchmark_ops/benchmark_decoding_attn.py @@ -6,6 +6,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import ( convert_kv_unpad_to_padded, create_attention_mask, generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, torch_attn_ref, ) from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data @@ -29,9 +30,9 @@ configs = [ x_vals=[2**i for i in range(8, 14)], # x_vals=[x for x in range(256, 8192, 256)], line_arg="provider", - line_vals=["torch", "triton"], - line_names=["Torch", "Triton"], - styles=[("red", "-"), ("blue", "-")], + line_vals=["torch", 
"triton", "triton_new_kcache_layout"], + line_names=["Torch", "Triton", "Triton New KCache Layout"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], ylabel="ms", plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}", args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1}, @@ -62,6 +63,14 @@ def bench_kernel( bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device ) max_seq_len_in_b = kv_lengths.max().item() # for random lengths + # the maximum block length splitted on kv should be the kv cache block size + kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size + sm_scale = 1.0 / (HEAD_DIM**0.5) + output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device) + mid_output = torch.empty( + size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + ) + mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) quantiles = [0.5, 0.2, 0.8] if provider == "torch": @@ -81,19 +90,11 @@ def bench_kernel( HEAD_DIM, ) ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) - if provider == "triton": + elif provider == "triton": k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) - # the maximum block length splitted on kv should be the kv cache block size - kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size - output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device) - mid_output = torch.empty( - size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device - ) - mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) - sm_scale = 1.0 / (HEAD_DIM**0.5) fn = lambda: flash_decoding_attention( # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), # refer to attention forward in modeling. @@ -111,6 +112,29 @@ def bench_kernel( kv_group_num=kv_group_num, ) # [bsz, 1, num_heads, head_dim] ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + elif provider == "triton_new_kcache_layout": + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + fn = lambda: flash_decoding_attention( + # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), + # refer to attention forward in modeling. 
+ q.squeeze(2), + k_cache, + v_cache, + kv_lengths, + block_tables, + block_size, + max_seq_len_in_b, + output, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=kv_group_num, + use_new_kcache_layout=True, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) return ms, min_ms, max_ms diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py index 9c9fdcebd..6a499ccf2 100644 --- a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py @@ -24,18 +24,20 @@ configs = [ x_vals=[2**i for i in range(4, 11)], line_arg="provider", line_vals=[ - "no_fused_triton_rotary_emb_func", - "fused_triton_rotary_emb_func", - "no_fused_cuda_rotary_emb_func", - "fused_cuda_rotary_emb_func", + "triton_rotary_emb_func", + "triton_fused_rotary_emb_func", + "triton_fused_rotary_emb_func_new_kcache_layout", + "cuda_rotary_emb_func", + "cuda_fused_rotary_emb_func", ], line_names=[ - "no_fused_triton_rotary_emb_func", - "fused_triton_rotary_emb_func", - "no_fused_cuda_rotary_emb_func", - "fused_cuda_rotary_emb_func", + "triton_rotary_emb_func", + "triton_fused_rotary_emb_func", + "triton_fused_rotary_emb_func(new layout)", + "cuda_rotary_emb_func", + "cuda_fused_rotary_emb_func", ], - styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")], + styles=[("red", "-"), ("blue", "-"), ("purple", "-"), ("green", "-"), ("yellow", "-")], ylabel="ms", plot_name=f"rotary_emb-batch-{BATCH}", args={"num_kv_heads": 16}, @@ -91,31 +93,44 @@ def benchmark_rotary_emb( kv_seq_lengths = past_kv_seq_lengths + 1 block_tables = block_tables.to(device="cuda") - if provider == "no_fused_triton_rotary_emb_func": + quantiles = [0.5, 0.2, 0.8] + if provider == "triton_rotary_emb_func": fn = lambda: [ rotary_embedding(new_q, new_k, cos, sin), copy_kv_to_blocked_cache( new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables ), ] - elif provider == "fused_triton_rotary_emb_func": + elif provider == "triton_fused_rotary_emb_func": fn = lambda: decoding_fused_rotary_embedding( new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths ) - elif provider == "no_fused_cuda_rotary_emb_func": + elif provider == "triton_fused_rotary_emb_func_new_kcache_layout": + x = 16 // torch.tensor([], dtype=dtype).element_size() + kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v3( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + block_tables = block_tables.to(device="cuda") + fn = lambda: decoding_fused_rotary_embedding( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout=True + ) + elif provider == "cuda_rotary_emb_func": fn = lambda: [ inference_ops.rotary_embedding(new_q, new_k, cos, sin, True), inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables), ] - elif provider == "fused_cuda_rotary_emb_func": + elif provider == "cuda_fused_rotary_emb_func": fn = lambda: inference_ops.rotary_embedding_and_cache_copy( new_q, new_k, new_v, cos, sin, 
new_k_cache, v_cache, kv_seq_lengths, block_tables, True ) else: raise ValueError("Undefined provider") - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - return ms + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles) + return ms, min_ms, max_ms if __name__ == "__main__": diff --git a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py index 8121eba59..03f797308 100644 --- a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py +++ b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py @@ -14,7 +14,7 @@ except ImportError: inference_ops = InferenceOpsLoader().load() -HEAD_DIM = 4 +HEAD_DIM = 128 BATCH = 16 BLOCK_SIZE = 32 SAME_LEN = True @@ -25,9 +25,9 @@ configs = [ x_names=["KV_SEQ_LEN"], x_vals=[2**i for i in range(8, 13)], line_arg="provider", - line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"], - line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"], - styles=[("red", "-"), ("blue", "-"), ("green", "-")], + line_vals=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"], + line_names=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")], ylabel="ms", plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}", args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True}, @@ -45,7 +45,7 @@ def benchmark_kvcache_copy( num_kv_heads: int, same_context_len: bool, ): - dtype = torch.float32 + dtype = torch.float16 device = get_current_device() assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len" @@ -63,11 +63,18 @@ def benchmark_kvcache_copy( ) quantiles = [0.5, 0.2, 0.8] - # TODO copy_to_cache needs to support copying both k and v at the same time in the future. 
if provider == "torch_copy_func": fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding") elif provider == "triton_copy_func": fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) + elif provider == "triton_new_kcache_layout": + # NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) to be applied + x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (bsz * max_seq_len // block_size, num_kv_heads, HEAD_DIM // x, block_size, x) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) # update k_cache layout + fn = lambda: copy_kv_to_blocked_cache( + new_k, new_v, k_cache, v_cache, context_lengths, block_tables, use_new_kcache_layout=True + ) elif provider == "cuda_copy_func": _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout( bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_ops/triton/test_decoding_attn.py index 5dc3c22c0..616d7868b 100644 --- a/tests/test_infer/test_ops/triton/test_decoding_attn.py +++ b/tests/test_infer/test_ops/triton/test_decoding_attn.py @@ -10,6 +10,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import ( convert_kv_unpad_to_padded, create_attention_mask, generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, torch_attn_ref, ) from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask @@ -75,6 +76,7 @@ def prepare_data( @pytest.mark.parametrize("same_context_len", [True, False]) @pytest.mark.parametrize("q_len", [1, 5]) @pytest.mark.parametrize("use_alibi_slopes", [True, False]) +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) def test_flash_decoding( bsz: int, block_size: int, @@ -84,7 +86,15 @@ def test_flash_decoding( same_context_len: bool, q_len: int, use_alibi_slopes: bool, + use_new_kcache_layout: bool, ): + if use_new_kcache_layout and use_alibi_slopes: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. + # And tests for the alibi kernel using new kcache layout will be added then. 
+ pytest.skip("Alibi kernel does not support new kcache layout yet.") + torch.manual_seed(123) torch.cuda.empty_cache() torch.cuda.synchronize() @@ -127,9 +137,14 @@ def test_flash_decoding( q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM ) - k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( - k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device - ) + if use_new_kcache_layout: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + else: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) block_tables = block_tables.to(device=device) # The maximum block length splitted on kv should be the kv cache block size kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size @@ -165,6 +180,7 @@ def test_flash_decoding( sm_scale=sm_scale, kv_group_num=kv_group_num, q_len=q_len, + use_new_kcache_layout=use_new_kcache_layout, ) # [bsz * q_len, num_heads, head_dim] assert out_torch.shape == out_triton.shape @@ -178,4 +194,4 @@ def test_flash_decoding( if __name__ == "__main__": - test_flash_decoding(16, 32, 32, 16, 1, True, 1, True) + test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True) diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_ops/triton/test_kvcache_copy.py index c4122a0c7..95126c087 100644 --- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer/test_ops/triton/test_kvcache_copy.py @@ -4,7 +4,11 @@ from packaging import version from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token +from tests.test_infer.test_ops.triton.kernel_utils import ( + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + mock_alloc_single_token, +) try: import triton # noqa @@ -30,6 +34,7 @@ def prepare_data( n=1, device="cuda", dtype=torch.float16, + use_new_kcache_layout=False, ): assert max_seq_len > n, "max_seq_len must be greater than n" @@ -44,9 +49,14 @@ def prepare_data( kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2) - k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( - k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device - ) + if use_new_kcache_layout: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device + ) + else: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device + ) block_tables = block_tables.to(device=device) new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device) @@ -66,8 +76,15 @@ def prepare_data( @pytest.mark.parametrize("num_kv_heads", [16]) @pytest.mark.parametrize("same_context_len", [True, False]) @pytest.mark.parametrize("n_tokens", [1, 5]) 
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) def test_copy_kv_to_caches( - bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, + n_tokens: int, + use_new_kcache_layout: bool, ): torch.manual_seed(123) torch.cuda.empty_cache() @@ -89,6 +106,7 @@ def test_copy_kv_to_caches( n_tokens, device=device, dtype=dtype, + use_new_kcache_layout=use_new_kcache_layout, ) k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1)) v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1)) @@ -98,7 +116,9 @@ def test_copy_kv_to_caches( offsets_in_block = past_kv_seq_lengths % block_size # Copy k (or v) to k (or v) cache - copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens) + copy_k_to_blocked_cache( + new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout + ) # Reshape target k from k cache to compare if matching with original tensor # Mainly to handle cases of n_tokens > 1 k_target = [] @@ -110,26 +130,39 @@ def test_copy_kv_to_caches( while tokens_left > 0: tokens_to_fill = min(block_size - offset, tokens_left) curr_block_id = block_table[curr_kv_len // block_size] - k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :]) + if use_new_kcache_layout: + k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :]) + else: + k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :]) curr_kv_len += tokens_to_fill tokens_left -= tokens_to_fill offset = 0 - k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim] - + if use_new_kcache_layout: + k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous() + k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM) + else: + k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim] assert k_target.shape == k_source.shape assert torch.equal(k_target, k_source) if n_tokens == 1: # Copy k and v to k/v caches k_cache = k_cache_copy - copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables) - k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :] - v_target = v_cache[target_block_ids, :, offsets_in_block, :] + copy_kv_to_blocked_cache( + new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout + ) + + if use_new_kcache_layout: + k_target = k_cache[target_block_ids, :, :, offsets_in_block, :] + k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM) + else: + k_target = k_cache[target_block_ids, :, offsets_in_block, :] assert k_target.shape == k_source.shape assert torch.equal(k_target, k_source) + v_target = v_cache[target_block_ids, :, offsets_in_block, :] assert v_target.shape == v_source.shape assert torch.equal(v_target, v_source) if __name__ == "__main__": - test_copy_kv_to_caches(4, 32, 8, 16, True) + test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1) diff --git a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py index 5b952730a..87eb38135 100644 --- a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py @@ -4,7 +4,10 @@ from packaging import 
version from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb from colossalai.kernel.triton import decoding_fused_rotary_embedding -from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2 +from tests.test_infer.test_ops.triton.kernel_utils import ( + mock_alloc_block_table_and_kvcache_v2, + mock_alloc_block_table_and_kvcache_v3, +) try: import triton # noqa @@ -36,7 +39,8 @@ def torch_rotary_emb(x, cos, sin): @pytest.mark.parametrize("H", [32]) @pytest.mark.parametrize("D", [64]) @pytest.mark.parametrize("dtype", [torch.float32]) -def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) +def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout): TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN # our crafted op equals to Transformers x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) @@ -57,28 +61,40 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") k_shape = (TOTAL_TOKENS, H, D) k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") - cos_shape = (TOTAL_TOKENS, D // 2) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D) - k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") v = torch.randn_like(k) - v_cache = torch.zeros_like(k_cache) - past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") - block_tables = mock_alloc_block_table_and_kvcache_v2( - k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size - ) new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda") new_q = torch.randn_like(new_k) new_v = torch.randn_like(new_k) + cos_shape = (TOTAL_TOKENS, D // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + + past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") + v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda") + + if use_new_kcache_layout: + x = 16 // torch.tensor([], dtype=dtype).element_size() + kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x) + k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v3( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + else: + k_cache = torch.zeros_like(v_cache) + block_tables = mock_alloc_block_table_and_kvcache_v2( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) kv_seq_lengths = past_kv_seq_lengths + 1 block_tables = block_tables.to(device="cuda") q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) - decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths) + decoding_fused_rotary_embedding( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout + ) assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4) if __name__ == "__main__": - test_rotary_emb(4, 64, 32, 64, torch.float32) + 
test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)
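
Note (not part of the patch): the sketch below illustrates, under the assumptions visible in this diff, how the new blocked KCache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x] relates to the original [num_blocks, num_kv_heads, block_size, head_dim] layout, with x derived from 16 bytes divided by the element size, as the benchmarks and tests above compute it. The helper name to_new_kcache_layout is hypothetical and used only for illustration; it is not an API from this repository.

# Illustrative sketch, not part of the patch: relate the original blocked KCache layout
# [num_blocks, num_kv_heads, block_size, head_dim] to the new layout
# [num_blocks, num_kv_heads, head_dim // x, block_size, x] used by the kernels above.
# `to_new_kcache_layout` is a hypothetical helper for illustration only.
import torch

num_blocks, num_kv_heads, block_size, head_dim = 8, 2, 16, 128
dtype = torch.float16
# Same choice of x as in the benchmarks/tests in this diff: 16 bytes worth of elements.
x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 for fp16


def to_new_kcache_layout(k_cache_old: torch.Tensor, x: int) -> torch.Tensor:
    # [num_blocks, num_kv_heads, block_size, head_dim]
    # -> [num_blocks, num_kv_heads, head_dim // x, block_size, x]
    nb, nh, bs, hd = k_cache_old.shape
    assert hd % x == 0, "head_dim must be divisible by x"
    return k_cache_old.reshape(nb, nh, bs, hd // x, x).permute(0, 1, 3, 2, 4).contiguous()


k_cache_old = torch.randn(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)
k_cache_new = to_new_kcache_layout(k_cache_old, x)
assert k_cache_new.shape == (num_blocks, num_kv_heads, head_dim // x, block_size, x)

# Element d of head_dim lives at (d // x, d % x) in the new layout, mirroring the
# (offsets_dmodel // x) * stride_kcsplit_x + (offsets_dmodel % x) * stride_kcd
# stride arithmetic used by the Triton kernels in this patch.
b, h, s, d = 3, 1, 5, 42
assert torch.equal(k_cache_old[b, h, s, d], k_cache_new[b, h, d // x, s, d % x])

A plausible reading of this layout choice, consistent with the "KCache must be contiguous" HACK comment in context_attn_unpad.py, is that keeping x elements of head_dim contiguous per token lets the copy and attention kernels read and write the key cache in small contiguous vectors per (token, split_x) pair.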