diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 355140bc1..9de3f040d 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -301,8 +301,9 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            copy_kv_to_blocked_cache(
+                key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
+            )
             attn_output = flash_decoding_attention(
                 q=query_states,
                 k_cache=k_cache,
diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py
index 2eac07d76..63050cd6d 100644
--- a/colossalai/inference/modeling/models/padding_llama.py
+++ b/colossalai/inference/modeling/models/padding_llama.py
@@ -356,8 +356,9 @@ class PadLlamaAttention(LlamaAttention):
             if attention_mask is not None:
                 attn_output = pad_input(attn_output, indices, bsz, q_len)
         else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            copy_kv_to_blocked_cache(
+                key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
+            )
             attn_output = flash_decoding_attention(
                 q=query_states,
                 k_cache=k_cache,
diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py
index 1aaeb6830..4f056acf6 100644
--- a/colossalai/kernel/triton/kvcache_copy.py
+++ b/colossalai/kernel/triton/kvcache_copy.py
@@ -6,17 +6,26 @@ import triton.language as tl
 # Triton 2.1.0
 @triton.jit
 def _copy_to_kvcache_seqlen1_kernel(
-    KV,  # K or V
-    KVCache,  # KCache or VCache
+    K,  # K
+    V,  # V
+    KCache,  # KCache
+    VCache,  # VCache
     BLOCK_TABLES,
     context_lengths,
     stride_kt,
     stride_kh,
     stride_kd,
-    stride_cacheb,
-    stride_cacheh,
-    stride_cachebs,
-    stride_cached,
+    stride_vt,
+    stride_vh,
+    stride_vd,
+    stride_cachekb,
+    stride_cachekh,
+    stride_cachekbs,
+    stride_cachekd,
+    stride_cachevb,
+    stride_cachevh,
+    stride_cachevbs,
+    stride_cachevd,
     stride_bts,
     stride_btb,
     block_size,
@@ -32,20 +41,33 @@ def _copy_to_kvcache_seqlen1_kernel(
     offsets_in_last_block = past_kv_seq_len % block_size
     offsets_dmodel = tl.arange(0, HEAD_DIM)
     offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
-    kv = tl.load(KV + offsets_kv)
+
+    k = tl.load(K + offsets_kv)
+    v = tl.load(V + offsets_kv)
+
     offsets_kvcache = (
-        block_id * stride_cacheb
-        + cur_kv_head_idx * stride_cacheh
-        + offsets_in_last_block * stride_cachebs
-        + offsets_dmodel * stride_cached
+        block_id * stride_cachekb
+        + cur_kv_head_idx * stride_cachekh
+        + offsets_in_last_block * stride_cachekbs
+        + offsets_dmodel * stride_cachekd
     )
-    tl.store(KVCache + offsets_kvcache, kv)
+    offsets_kvcache = (
+        block_id * stride_cachevb
+        + cur_kv_head_idx * stride_cachevh
+        + offsets_in_last_block * stride_cachevbs
+        + offsets_dmodel * stride_cachevd
+    )
+
+    tl.store(KCache + offsets_kvcache, k)
+    tl.store(VCache + offsets_kvcache, v)
     return
 
 
 def copy_kv_to_blocked_cache(
     k: torch.Tensor,
+    v: torch.Tensor,
     k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
     kv_lengths: torch.Tensor,
     block_tables: torch.Tensor,
 ):
@@ -53,16 +75,23 @@ def copy_kv_to_blocked_cache(
     Copy keys or values to the blocked key/value cache during decoding stage.
 
     Args:
-        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
-        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
+        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
+        v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
+        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
+        v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
         kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
         block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
     """
     assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
     assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
-    k = k.squeeze(1) if k.dim() == 4 else k
     assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
+
+    assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
+    assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
+    v = v.squeeze(1) if v.dim() == 4 else v
+    assert v.dim() == 3, f"Incompatible v dim {v.dim()}"
+
 
     bsz, num_kv_heads, head_dim = k.shape
 
     assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
@@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
     block_size = k_cache.size(-2)
 
     num_warps = 8 if head_dim > 128 else 4
-
     grid = (bsz, num_kv_heads)
     _copy_to_kvcache_seqlen1_kernel[grid](
         k,
+        v,
         k_cache,
+        v_cache,
         block_tables,
         kv_lengths,
         k.stride(0),
         k.stride(1),
         k.stride(2),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
         k_cache.stride(3),
+        v_cache.stride(0),
+        v_cache.stride(1),
+        v_cache.stride(2),
+        v_cache.stride(3),
         block_tables.stride(0),
         block_tables.stride(1),
         block_size,
diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_ops/triton/test_kvcache_copy.py
index 5612f2bd9..53475270e 100644
--- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py
+++ b/tests/test_infer/test_ops/triton/test_kvcache_copy.py
@@ -44,18 +44,19 @@ def prepare_data(
     kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
 
-    k_cache, _, block_tables = generate_caches_and_block_tables_v2(
+    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
         k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
     )
     block_tables = block_tables.to(device=device)
 
     new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
+    new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
     # mock allocating blocks for the new k/v and update block tables
     mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
     # kv seq len = past kv seq len + seq len (1 during decoding stage)
     kv_seq_lengths = past_kv_seq_lengths + 1
 
-    return new_k, k_cache, kv_seq_lengths, block_tables
+    return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
     dtype = torch.float16
     device = get_current_device()
 
-    new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
+    new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
         bsz,
         num_kv_heads,
         HEAD_DIM,
@@ -93,16 +94,20 @@
     )
     # k_cache_torch = k_cache.clone().detach()
    # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
-    copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
+    copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
 
     past_kv_seq_len = kv_seq_lengths - 1
     target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
     offsets_in_block = past_kv_seq_len % block_size
-    target = k_cache[target_block_ids, :, offsets_in_block, :]
-    source = new_k.squeeze()
+    k_target = k_cache[target_block_ids, :, offsets_in_block, :]
+    k_source = new_k.squeeze()
+    v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+    v_source = new_v.squeeze()
 
-    assert target.shape == source.shape
-    assert torch.equal(target, source)
+    assert k_target.shape == k_source.shape
+    assert torch.equal(k_target, k_source)
+    assert v_target.shape == v_source.shape
+    assert torch.equal(v_target, v_source)
     # target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
     # assert target_torch.shape == source.shape
     # assert torch.equal(target_torch, source)
@@ -143,7 +148,7 @@ def benchmark_kvcache_copy(
 
     assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
 
-    new_k, k_cache, context_lengths, block_tables = prepare_data(
+    new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
         bsz,
         num_kv_heads,
         HEAD_DIM,
@@ -156,10 +161,11 @@
     )
 
     quantiles = [0.5, 0.2, 0.8]
+    # TODO copy_to_cache needs to support copying both k and v at the same time in the future.
     if provider == "torch_copy_func":
         fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
     if provider == "triton_copy_func":
-        fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
+        fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
 
     ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
     return ms, min_ms, max_ms