mirror of https://github.com/hpcaitech/ColossalAI
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy * add TODO in test_kvcache_copy.pypull/5337/head
parent
58740b5f68
commit
6fb4bcbb24
|
@ -301,8 +301,9 @@ class NopadLlamaAttention(LlamaAttention):
|
|||
sm_scale=sm_scale,
|
||||
)
|
||||
else:
|
||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(
|
||||
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
|
|
|
@ -356,8 +356,9 @@ class PadLlamaAttention(LlamaAttention):
|
|||
if attention_mask is not None:
|
||||
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
||||
else:
|
||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(
|
||||
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
|
|
|
@ -6,17 +6,26 @@ import triton.language as tl
|
|||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _copy_to_kvcache_seqlen1_kernel(
|
||||
KV, # K or V
|
||||
KVCache, # KCache or VCache
|
||||
K, # K
|
||||
V, # V
|
||||
KCache, # KCache
|
||||
VCache, # VCache
|
||||
BLOCK_TABLES,
|
||||
context_lengths,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_cachekb,
|
||||
stride_cachekh,
|
||||
stride_cachekbs,
|
||||
stride_cachekd,
|
||||
stride_cachevb,
|
||||
stride_cachevh,
|
||||
stride_cachevbs,
|
||||
stride_cachevd,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
block_size,
|
||||
|
@ -32,20 +41,33 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||
offsets_in_last_block = past_kv_seq_len % block_size
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
|
||||
k = tl.load(K + offsets_kv)
|
||||
v = tl.load(V + offsets_kv)
|
||||
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cacheb
|
||||
+ cur_kv_head_idx * stride_cacheh
|
||||
+ offsets_in_last_block * stride_cachebs
|
||||
+ offsets_dmodel * stride_cached
|
||||
block_id * stride_cachekb
|
||||
+ cur_kv_head_idx * stride_cachekh
|
||||
+ offsets_in_last_block * stride_cachekbs
|
||||
+ offsets_dmodel * stride_cachekd
|
||||
)
|
||||
tl.store(KVCache + offsets_kvcache, kv)
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cachevb
|
||||
+ cur_kv_head_idx * stride_cachevh
|
||||
+ offsets_in_last_block * stride_cachevbs
|
||||
+ offsets_dmodel * stride_cachevd
|
||||
)
|
||||
|
||||
tl.store(KCache + offsets_kvcache, k)
|
||||
tl.store(VCache + offsets_kvcache, v)
|
||||
return
|
||||
|
||||
|
||||
def copy_kv_to_blocked_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
kv_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
):
|
||||
|
@ -53,16 +75,23 @@ def copy_kv_to_blocked_cache(
|
|||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
|
||||
v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
|
||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
|
||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
|
||||
k = k.squeeze(1) if k.dim() == 4 else k
|
||||
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||
|
||||
assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
|
||||
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
v = v.squeeze(1) if v.dim() == 4 else v
|
||||
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"
|
||||
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
|
@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
|
|||
block_size = k_cache.size(-2)
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
|
||||
grid = (bsz, num_kv_heads)
|
||||
_copy_to_kvcache_seqlen1_kernel[grid](
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
kv_lengths,
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
block_size,
|
||||
|
|
|
@ -44,18 +44,19 @@ def prepare_data(
|
|||
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
k_cache, _, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
|
||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||
# mock allocating blocks for the new k/v and update block tables
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
|
||||
return new_k, k_cache, kv_seq_lengths, block_tables
|
||||
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
|
@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
|
|||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
HEAD_DIM,
|
||||
|
@ -93,16 +94,20 @@ def test_copy_kv_to_caches(
|
|||
)
|
||||
# k_cache_torch = k_cache.clone().detach()
|
||||
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
||||
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||
|
||||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
source = new_k.squeeze()
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
k_source = new_k.squeeze()
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
assert target.shape == source.shape
|
||||
assert torch.equal(target, source)
|
||||
assert k_target.shape == k_source.shape
|
||||
assert torch.equal(k_target, k_source)
|
||||
assert v_target.shape == v_source.shape
|
||||
assert torch.equal(v_target, v_source)
|
||||
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
||||
# assert target_torch.shape == source.shape
|
||||
# assert torch.equal(target_torch, source)
|
||||
|
@ -143,7 +148,7 @@ def benchmark_kvcache_copy(
|
|||
|
||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||
|
||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
HEAD_DIM,
|
||||
|
@ -156,10 +161,11 @@ def benchmark_kvcache_copy(
|
|||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
||||
if provider == "torch_copy_func":
|
||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||
if provider == "triton_copy_func":
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
return ms, min_ms, max_ms
|
||||
|
|
Loading…
Reference in New Issue