diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 82a922650..4d2c17db1 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -11,7 +11,7 @@ if HAS_TRITON: from .context_attn_unpad import context_attention_unpadded from .flash_decoding import flash_decoding_attention from .fused_rotary_embedding import fused_rotary_embedding - from .kvcache_copy import copy_kv_to_blocked_cache + from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding from .rms_layernorm import rms_layernorm from .rotary_cache_copy import get_xine_cache @@ -20,6 +20,7 @@ if HAS_TRITON: __all__ = [ "context_attention_unpadded", "flash_decoding_attention", + "copy_k_to_blocked_cache", "copy_kv_to_blocked_cache", "softmax", "rms_layernorm", diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index d351b20da..e1ccffe53 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -9,13 +9,14 @@ import triton.language as tl # Triton 2.1.0 @triton.jit def _flash_decoding_fwd_kernel( - Q, # [batch_size, head_num, q_len(1), head_dim] + Q, # [batch_size * q_len, head_num, head_dim] KCache, # [num_blocks, num_kv_heads, block_size, head_dim] VCache, # [num_blocks, num_kv_heads, block_size, head_dim] block_tables, # [batch_size, max_blocks_per_sequence] - mid_o, # [batch_size, head_num, kv_split_num, head_dim] - mid_o_lse, # [batch_size, head_num, kv_split_num] + mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size * q_len, head_num, kv_split_num] kv_seq_len, # [batch_size] + q_len, batch_size, stride_qt, stride_qh, @@ -39,44 +40,37 @@ def _flash_decoding_fwd_kernel( BLOCK_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr, ): - cur_seq_idx = tl.program_id(0) + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // q_len if cur_seq_idx >= batch_size: return cur_head_idx = tl.program_id(1) block_start_kv = tl.program_id(2) # for splitting k/v - cur_kv_head_idx = cur_head_idx // KV_GROUPS - offsets_dmodel = tl.arange(0, HEAD_DIM) - # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) # and then support calculating multiple kv cache blocks on an instance tl.static_assert(BLOCK_KV == BLOCK_SIZE) - - # get the current (kv) sequence length from provided context lengths tensor + # get the current (kv) sequence length cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + if block_start_kv * BLOCK_KV >= cur_kv_seq_len: + return - offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd q = tl.load(Q + offsets_q) - # block table for the current sequence block_table_ptr = block_tables + cur_seq_idx * stride_bts - - # actually current block table current block start idx # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE) - cur_bt_start_idx = block_start_kv - cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) - - if block_start_kv * BLOCK_KV >= cur_kv_seq_len: - return - + # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) + cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb) cur_occupied_size = tl.where( (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE ) tl.device_assert(cur_occupied_size >= 0) + cur_kv_head_idx = cur_head_idx // KV_GROUPS offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh - K_block_ptr = tl.make_block_ptr( base=KCache + offset_kvcache, shape=(cur_occupied_size, HEAD_DIM), @@ -115,14 +109,14 @@ def _flash_decoding_fwd_kernel( acc = acc / l offsets_mid_o = ( - cur_seq_idx * stride_mid_ot + cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + block_start_kv * stride_mid_ob + offsets_dmodel * stride_mid_od ) tl.store(mid_o + offsets_mid_o, acc) offsets_mid_o_lse = ( - cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb + cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb ) # logsumexp L^(j) = m^(j) + log(l^(j)) tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) @@ -135,6 +129,7 @@ def _flash_decoding_fwd_reduce_kernel( mid_o_lse, # [batch_size, head_num, kv_split_num] O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim] kv_seq_len, + q_len, batch_size, stride_mid_ot, stride_mid_oh, @@ -149,7 +144,8 @@ def _flash_decoding_fwd_reduce_kernel( BLOCK_KV: tl.constexpr, HEAD_DIM: tl.constexpr, ): - cur_seq_idx = tl.program_id(0) + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // q_len if cur_seq_idx >= batch_size: return cur_head_idx = tl.program_id(1) @@ -164,8 +160,8 @@ def _flash_decoding_fwd_reduce_kernel( l = 0.0 # sum exp acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel - offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh + offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel + offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh for block_i in range(0, kv_split_num, 1): mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob) lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb) @@ -179,7 +175,7 @@ def _flash_decoding_fwd_reduce_kernel( m_i = m_ij acc = acc / l - offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel + offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel tl.store(O + offsets_O, acc.to(O.type.element_ty)) return @@ -199,12 +195,14 @@ def flash_decoding_attention( mid_output_lse: torch.Tensor = None, sm_scale: int = None, kv_group_num: int = 1, + q_len: int = 1, ): """ Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. Args: - q (torch.Tensor): [bsz, num_heads, head_dim] + q (torch.Tensor): [bsz * q_len, num_heads, head_dim] + q_len > 1 only for verification process in speculative-decoding. k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] kv_seq_len (torch.Tensor): [batch_size] @@ -212,19 +210,25 @@ def flash_decoding_attention( block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] max_seq_len_in_batch (int): Maximum sequence length in the batch. output (torch.Tensor): [bsz, num_heads * head_dim] - mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] + mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim] Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. - mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] + q_len > 1 only for verification process in speculative-decoding. + mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num] Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`. + q_len > 1 only for verification process in speculative-decoding. block_size (int): Size of each block in the blocked key/value cache. num_kv_group (int, optional): Number of key/value groups. Defaults to 1. + q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens). + Defaults to 1. Returns: - Output tensor with shape [bsz, num_heads * head_dim] + Output tensor with shape [bsz * q_len, num_heads * head_dim] """ q = q.squeeze() if q.dim() == 4 else q assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" - bsz, num_heads, head_dim = q.shape + n_tokens, num_heads, head_dim = q.shape + assert n_tokens % q_len == 0, "Invalid q_len" + bsz = n_tokens // q_len assert head_dim in {32, 64, 128, 256} assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( @@ -247,22 +251,31 @@ def flash_decoding_attention( max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch # For compatibility (TODO revise modeling in future) kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV - mid_output = ( - torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device) - if mid_output is None - else mid_output - ) - mid_output_lse = ( - torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) - if mid_output_lse is None - else mid_output_lse - ) + + if mid_output is None: + mid_output = torch.empty( + (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device + ) + if mid_output_lse is None: + mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + if output is None: + # A hack to prevent `view` operation in modeling + output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device) + + assert ( + mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num + ), "Incompatible kv split number of intermediate output tensors" + assert ( + mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens + ), f"Incompatible first dimension of output tensors" # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) - grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV)) - output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output - + grid = ( + triton.next_power_of_2(bsz * q_len), + num_heads, + triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV), + ) _flash_decoding_fwd_kernel[grid]( q, k_cache, @@ -271,6 +284,7 @@ def flash_decoding_attention( mid_output, mid_output_lse, kv_seq_len, + q_len, bsz, q.stride(0), q.stride(1), @@ -295,13 +309,13 @@ def flash_decoding_attention( HEAD_DIM=head_dim, ) - grid = (triton.next_power_of_2(bsz), num_heads) - + grid = (triton.next_power_of_2(bsz * q_len), num_heads) _flash_decoding_fwd_reduce_kernel[grid]( mid_output, mid_output_lse, output, kv_seq_len, + q_len, bsz, mid_output.stride(0), mid_output.stride(1), diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py index 96ab922e3..871f1f6d8 100644 --- a/colossalai/kernel/triton/kvcache_copy.py +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -3,6 +3,50 @@ import triton import triton.language as tl +# Triton 2.1.0 +@triton.jit +def _copy_to_kcache_seqlen_n_kernel( + KV, # K or V + KVCache, # KCache or VCache + BLOCK_TABLES, + context_lengths, + stride_kt, + stride_kh, + stride_kd, + stride_cacheb, + stride_cacheh, + stride_cachebs, + stride_cached, + stride_bts, + stride_btb, + block_size, + n, + HEAD_DIM: tl.constexpr, +): + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // n + cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1)) + # cur_token_shift = cur_token_idx - n * cur_seq_idx + cur_kv_head_idx = tl.program_id(1) + + past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift + last_bt_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) + offset_last_block = past_kv_seq_len % block_size + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd + kv = tl.load(KV + offsets_kv) + offsets_kvcache = ( + block_id * stride_cacheb + + cur_kv_head_idx * stride_cacheh + + offset_last_block * stride_cachebs + + offsets_dmodel * stride_cached + ) + tl.store(KVCache + offsets_kvcache, kv) + return + + # Triton 2.1.0 @triton.jit def _copy_to_kvcache_seqlen1_kernel( @@ -40,10 +84,11 @@ def _copy_to_kvcache_seqlen1_kernel( block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) offsets_in_last_block = past_kv_seq_len % block_size offsets_dmodel = tl.arange(0, HEAD_DIM) - offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd + offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd + offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd - k = tl.load(K + offsets_kv) - v = tl.load(V + offsets_kv) + k = tl.load(K + offsets_k) + v = tl.load(V + offsets_v) offsets_kcache = ( block_id * stride_cachekb @@ -63,6 +108,64 @@ def _copy_to_kvcache_seqlen1_kernel( return +def copy_k_to_blocked_cache( + k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1 +): + """ + Copy keys or values to the blocked key/value cache during decoding stage. + + Args: + k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. + [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n + k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache. + kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. + n (int): Number of tokens to copy for each sequence. Default to 1. + """ + assert k.size(-1) == k_cache.size(-1), "Incompatible head dim" + assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." + + k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k + assert k.dim() == 3, f"Invalid k dim {k.dim()}" + bsz, num_kv_heads, head_dim = k.shape + # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim] + if n > 1: + assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied" + bsz = bsz // n + + assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( + f"Got incompatible batch size (number of seqs):\n" + f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; " + f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}" + ) + + # Modify if the shape of kv cahce is changed. + block_size = k_cache.size(-2) + + num_warps = 8 if head_dim > 128 else 4 + + grid = (bsz * n, num_kv_heads) + _copy_to_kcache_seqlen_n_kernel[grid]( + k, + k_cache, + block_tables, + kv_lengths, + k.stride(0), + k.stride(1), + k.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + block_size, + n=n, + HEAD_DIM=head_dim, + num_warps=num_warps, + ) + + def copy_kv_to_blocked_cache( k: torch.Tensor, v: torch.Tensor, diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_ops/triton/kernel_utils.py index 22167ded0..f1ae45477 100644 --- a/tests/test_infer/test_ops/triton/kernel_utils.py +++ b/tests/test_infer/test_ops/triton/kernel_utils.py @@ -19,12 +19,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim) -def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"): - padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device) +def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"): + padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device) for i in range(bsz): cur_seq_len = kv_lengths[i].item() - assert cur_seq_len <= kv_seq_len - padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf") + assert cur_seq_len <= kv_len + padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf") return padding_mask @@ -33,12 +33,12 @@ def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, de # https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350 def torch_attn_ref( q: torch.Tensor, # [bsz, num_heads, q_len, head_dim] - k: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim] - v: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim] - attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len] + k: torch.Tensor, # [bsz, num_heads, kv_len, head_dim] + v: torch.Tensor, # [bsz, num_heads, kv_len, head_dim] + attention_mask: torch.Tensor, # [bsz, 1, q_len, kv_len] bsz: int, - seq_len: int, - kv_seq_len: int, + q_len: int, + kv_len: int, num_heads: int, num_kv_heads: int, head_dim: int, @@ -54,22 +54,22 @@ def torch_attn_ref( qk = torch.matmul(q, k.transpose(2, 3)) attn_scores = qk / (head_dim**0.5) - assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores" + + assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores" # for left-side padding - if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + if attention_mask.size() != (bsz, 1, q_len, kv_len): + raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}") attn_scores = attn_scores + attention_mask attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype) out = torch.matmul(attn_weights, v) - if out.size() != (bsz, num_heads, seq_len, head_dim): + if out.size() != (bsz, num_heads, q_len, head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}" + f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}" ) out = out.transpose(1, 2).contiguous() - out = out.squeeze(1) + out = out.view(-1, out.size(-2), out.size(-1)) + # out [bsz * q_len, num_heads, head_dim] return out diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_ops/triton/test_decoding_attn.py index 2ce0f9d04..77354e1bb 100644 --- a/tests/test_infer/test_ops/triton/test_decoding_attn.py +++ b/tests/test_infer/test_ops/triton/test_decoding_attn.py @@ -21,7 +21,6 @@ except ImportError: TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -Q_LEN = 1 HEAD_DIM = 128 @@ -64,6 +63,7 @@ def prepare_data( @pytest.mark.parametrize("num_attn_heads", [16]) @pytest.mark.parametrize("kv_group_num", [1, 2, 16]) @pytest.mark.parametrize("same_context_len", [True, False]) +@pytest.mark.parametrize("q_len", [1, 5]) def test_flash_decoding( bsz: int, block_size: int, @@ -71,6 +71,7 @@ def test_flash_decoding( num_attn_heads: int, kv_group_num: int, same_context_len: bool, + q_len: int, ): torch.manual_seed(123) torch.cuda.empty_cache() @@ -82,47 +83,57 @@ def test_flash_decoding( max_seq_len = block_size * max_num_blocks_per_seq dtype = torch.float16 device = get_current_device() + q, k_unpad, v_unpad, kv_lengths = prepare_data( + bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device + ) + # The maximum sequence length in the batch (if context lengths randomly generated) + max_kv_len_in_b = kv_lengths.max().item() - q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( - bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b) + torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device) + out_torch = torch_attn_ref( + q, k_torch, v_torch, torch_padding_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM ) + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( - k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) - # The maximum sequence length in the batch (if context lengths randomly generated) - max_seq_len_in_b = kv_seq_lengths.max().item() # The maximum block length splitted on kv should be the kv cache block size - kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size - output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device) + kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size + output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device) mid_output = torch.empty( - size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + ) + mid_output_lse = torch.empty( + size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device ) - mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) sm_scale = 1.0 / (HEAD_DIM**0.5) + # Here we use different methods to hide the q_len dimension, + # refer to attention forward function in modeling. + if q_len > 1: + q = q.transpose(1, 2).contiguous() # [bsz, q_len, num_heads, head_dim] + q = q.view(-1, q.size(-2), q.size(-1)) # [bsz * q_len, num_heads, head_dim] + else: + q = q.squeeze(2) + assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM) + out_triton = flash_decoding_attention( - # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), - # refer to attention forward in modeling. - q.squeeze(2), + q, k_cache, v_cache, - kv_seq_lengths, + kv_lengths, block_tables, block_size, - max_seq_len_in_b, + max_kv_len_in_b, output, mid_output, mid_output_lse, sm_scale=sm_scale, kv_group_num=kv_group_num, - ) # [bsz, 1, num_heads, head_dim] - - k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, bsz, max_seq_len_in_b) - v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, bsz, max_seq_len_in_b) - torch_padding_mask = prepare_padding_mask(kv_seq_lengths, bsz, max_seq_len_in_b, q.device) - out_torch = torch_attn_ref( - q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM - ) + q_len=q_len, + ) # [bsz * q_len, num_heads, head_dim] assert out_torch.shape == out_triton.shape assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_ops/triton/test_kvcache_copy.py index b3fdd4b88..43545df79 100644 --- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer/test_ops/triton/test_kvcache_copy.py @@ -2,7 +2,8 @@ import pytest import torch from packaging import version -from colossalai.kernel.triton import copy_kv_to_blocked_cache +from colossalai.inference.modeling.layers.attention import copy_to_cache +from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token @@ -16,7 +17,7 @@ except ImportError: TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -HEAD_DIM = 128 +HEAD_DIM = 32 def prepare_data( @@ -27,15 +28,16 @@ def prepare_data( max_num_blocks_per_seq, same_context_len, max_seq_len, + n, device, dtype=torch.float16, ): - # past_kv_seq_lengths in this test records the previous kv seq len - # (not incorporating the current input whose seq len is 1) + assert max_seq_len > n, "max_seq_len must be greater than n" + past_kv_seq_lengths = ( - torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device) + torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device) if same_context_len - else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device) + else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device) ) num_tokens = torch.sum(past_kv_seq_lengths).item() @@ -48,14 +50,14 @@ def prepare_data( ) block_tables = block_tables.to(device=device) - new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device) - new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device) + new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device) + new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device) # mock allocating blocks for the new k/v and update block tables - mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) - # kv seq len = past kv seq len + seq len (1 during decoding stage) - kv_seq_lengths = past_kv_seq_lengths + 1 + for _ in range(n): + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + past_kv_seq_lengths += 1 - return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables + return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") @@ -64,12 +66,9 @@ def prepare_data( @pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) @pytest.mark.parametrize("num_kv_heads", [16]) @pytest.mark.parametrize("same_context_len", [True, False]) +@pytest.mark.parametrize("n_tokens", [1, 5]) def test_copy_kv_to_caches( - bsz: int, - block_size: int, - max_num_blocks_per_seq: int, - num_kv_heads: int, - same_context_len: bool, + bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int ): torch.manual_seed(123) torch.cuda.empty_cache() @@ -88,25 +87,49 @@ def test_copy_kv_to_caches( max_num_blocks_per_seq, same_context_len, max_seq_len, + n_tokens, device=device, dtype=dtype, ) - # k_cache_torch = k_cache.clone().detach() - # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding") - copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables) - - past_kv_seq_len = kv_seq_lengths - 1 - target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] - offsets_in_block = past_kv_seq_len % block_size - k_target = k_cache[target_block_ids, :, offsets_in_block, :] - k_source = new_k.squeeze() - v_target = v_cache[target_block_ids, :, offsets_in_block, :] - v_source = new_v.squeeze() + k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1)) + v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1)) + k_cache_copy = k_cache.detach().clone() + past_kv_seq_lengths = kv_seq_lengths - n_tokens + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size] + offsets_in_block = past_kv_seq_lengths % block_size + + # Copy k (or v) to k (or v) cache + copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens) + # Reshape target k from k cache to compare if matching with original tensor + # Mainly to handle cases of n_tokens > 1 + k_target = [] + for i in range(bsz): + block_table = block_tables[i] + curr_kv_len = past_kv_seq_lengths[i].item() + offset = offsets_in_block[i].item() + tokens_left = n_tokens + while tokens_left > 0: + tokens_to_fill = min(block_size - offset, tokens_left) + curr_block_id = block_table[curr_kv_len // block_size] + k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :]) + curr_kv_len += tokens_to_fill + tokens_left -= tokens_to_fill + offset = 0 + k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim] assert k_target.shape == k_source.shape assert torch.equal(k_target, k_source) - assert v_target.shape == v_source.shape - assert torch.equal(v_target, v_source) + + if n_tokens == 1: + # Copy k and v to k/v caches + k_cache = k_cache_copy + copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables) + k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :] + v_target = v_cache[target_block_ids, :, offsets_in_block, :] + assert k_target.shape == k_source.shape + assert torch.equal(k_target, k_source) + assert v_target.shape == v_source.shape + assert torch.equal(v_target, v_source) if __name__ == "__main__":