mirror of https://github.com/hpcaitech/ColossalAI
[kernel] Revise KVCache copy triton kernel API (#5273)
* [kernel/fix] revise kvcache copy kernel api * fix benchmarkpull/5264/head
parent
d8db500efc
commit
0f2b46a41c
|
@ -25,11 +25,11 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||
cur_seq_idx = tl.program_id(0)
|
||||
cur_kv_head_idx = tl.program_id(1)
|
||||
|
||||
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
last_bt_block_idx = cur_kv_seq_len // block_size
|
||||
past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1
|
||||
last_bt_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||
offsets_in_last_block = (cur_kv_seq_len % block_size) * stride_cachebs
|
||||
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
|
@ -43,23 +43,30 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||
return
|
||||
|
||||
|
||||
# Used with blocked kv cache.
|
||||
# Copy k or v to block k/v cache during decoding stage
|
||||
def copy_kv_to_blocked_cache(
|
||||
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_dim], k or v during decoding stage
|
||||
k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size], blocked k or v cache (for now, the shapes of them are the same)
|
||||
context_lengths: torch.Tensor, # [bsz], past kv seq len (not incorporating the current kv of length 1)
|
||||
block_tables: torch.Tensor, # [bsz, max_blocks_per_sequence]
|
||||
k: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
kv_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Parameters:
|
||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
|
||||
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
|
||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
||||
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
bsz, _, num_kv_heads, head_dim = k.shape
|
||||
assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f"batch size {bsz}"
|
||||
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
|
||||
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
|
||||
)
|
||||
|
||||
# Modify if the shape of kv cahce is changed.
|
||||
|
@ -74,7 +81,7 @@ def copy_kv_to_blocked_cache(
|
|||
k,
|
||||
k_cache,
|
||||
block_tables,
|
||||
context_lengths,
|
||||
kv_lengths,
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
|
|
|
@ -30,12 +30,12 @@ def prepare_data(
|
|||
dtype=torch.float16,
|
||||
):
|
||||
if same_context_len:
|
||||
# context_lengths in this test records the previous kv seq len
|
||||
# past_kv_seq_lengths in this test records the previous kv seq len
|
||||
# (not incorporating the current input whose seq len is 1)
|
||||
context_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(past_kv_seq_lengths).item()
|
||||
|
||||
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
||||
kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
@ -46,15 +46,18 @@ def prepare_data(
|
|||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
# Mock allocation on block tables as well as blocked kv caches
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
|
||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||
# mock allocating blocks for the new k/v and update block tables
|
||||
mock_alloc_single_token(block_tables, context_lengths, block_size)
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
|
||||
return new_k, k_cache, context_lengths, block_tables
|
||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
|
||||
return new_k, k_cache, kv_seq_lengths, block_tables
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
|
@ -80,7 +83,7 @@ def test_copy_kv_to_caches(
|
|||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
|
@ -91,25 +94,24 @@ def test_copy_kv_to_caches(
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
||||
|
||||
for seq_i in range(bsz):
|
||||
ki = new_k[seq_i]
|
||||
ki = ki.squeeze()
|
||||
context_len_i = context_lengths[seq_i]
|
||||
target_block_id = block_tables[seq_i, context_len_i // block_size]
|
||||
offsets_in_block = context_len_i % block_size
|
||||
past_kv_seq_len = kv_seq_lengths[seq_i] - 1
|
||||
target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
target = k_cache[target_block_id, :, :, offsets_in_block]
|
||||
orig = new_k[seq_i].squeeze(dim=0)
|
||||
assert torch.equal(orig, target)
|
||||
|
||||
|
||||
BATCH = 4
|
||||
BATCH = 16
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["PAST_KVLEN"],
|
||||
x_vals=[2**i - 1 for i in range(8, 13)],
|
||||
x_names=["KV_SEQ_LEN"],
|
||||
x_vals=[2**i for i in range(8, 13)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_copy_func", "triton_copy_func"],
|
||||
line_names=["torch_copy_func", "triton_copy_func"],
|
||||
|
@ -127,7 +129,7 @@ def benchmark_kvcache_copy(
|
|||
bsz: int,
|
||||
block_size: int,
|
||||
max_seq_len: int,
|
||||
PAST_KVLEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
|
||||
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
|
||||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
|
@ -138,7 +140,7 @@ def benchmark_kvcache_copy(
|
|||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
assert PAST_KVLEN < max_seq_len, "Assigned maximum past kv length must be smaller or equal to maximum seq len"
|
||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||
|
||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
|
@ -147,7 +149,7 @@ def benchmark_kvcache_copy(
|
|||
block_size,
|
||||
max_seq_len // block_size,
|
||||
same_context_len,
|
||||
PAST_KVLEN,
|
||||
KV_SEQ_LEN,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -164,5 +166,5 @@ def benchmark_kvcache_copy(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, False)
|
||||
# benchmark_kvcache_copy.run(save_path=".")
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
||||
# benchmark_kvcache_copy.run(save_path=".", print_data=True)
|
||||
|
|
Loading…
Reference in New Issue