[kernel] Add KV cache copy kernel during decoding (#5261)

* add kv copy triton kernel during decoding stage

* add pytest and fix kernel

* fix test utilities

* revise kernel config

* add benchmark for kvcache copy
Yuanheng Zhao 2024-01-15 17:37:20 +08:00 committed by GitHub
parent 1ded7e81ef
commit fa85e02b3b
5 changed files with 288 additions and 2 deletions


@@ -31,7 +31,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
            1, 2, 0
        )
    elif type == "decoding":
-       assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
+       assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
        source = source.squeeze(1)
        slot_idx = (lengths + block_size - 1) % block_size
        for i in range(bsz):
@@ -314,4 +314,4 @@ class PagedAttention:
    ):
        return self.pad_decoding_forward(
            q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), k_cache, v_cache, lengths, block_tables
        )
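
For context on the torch baseline, the decoding branch of copy_to_cache above computes the in-block slot as (lengths + block_size - 1) % block_size, which is algebraically the same as (lengths - 1) % block_size, i.e. the offset of the last token within its cache block. A minimal sketch of that arithmetic (plain PyTorch, hypothetical values, not part of this diff; it assumes `lengths` counts the newest token as well):

import torch

block_size = 16
lengths = torch.tensor([17, 32, 33])  # total kv lengths, assumed to include the newest token
slot_idx = (lengths + block_size - 1) % block_size
# equivalent to (lengths - 1) % block_size -> in-block offset of the newest token
print(slot_idx)  # tensor([0, 15, 0])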


@@ -12,12 +12,14 @@ if HAS_TRITON:
    from .flash_decoding import flash_decoding_fwd
    from .fused_layernorm import layer_norm
    from .gptq_triton import gptq_fused_linear_triton
+   from .kvcache_copy import copy_kv_to_blocked_cache
    from .no_pad_rotary_embedding import rotary_embedding
    from .softmax import softmax

    __all__ = [
        "context_attention_unpadded",
        "flash_decoding_fwd",
+       "copy_kv_to_blocked_cache",
        "softmax",
        "layer_norm",
        "gptq_fused_linear_triton",


@@ -0,0 +1,90 @@
import torch
import triton
import triton.language as tl


# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
    KV,  # K or V
    KVCache,  # KCache or VCache
    BLOCK_TABLES,
    context_lengths,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_cacheb,
    stride_cacheh,
    stride_cached,
    stride_cachebs,
    stride_bts,
    stride_btb,
    block_size,
    HEAD_DIM: tl.constexpr,
):
    cur_seq_idx = tl.program_id(0)
    cur_kv_head_idx = tl.program_id(1)

    cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
    last_bt_block_idx = cur_kv_seq_len // block_size
    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
    block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
    offsets_in_last_block = (cur_kv_seq_len % block_size) * stride_cachebs
    offsets_dmodel = tl.arange(0, HEAD_DIM)
    offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
    kv = tl.load(KV + offsets_kv)
    offsets_kvcache = (
        block_id * stride_cacheb
        + cur_kv_head_idx * stride_cacheh
        + offsets_dmodel * stride_cached
        + offsets_in_last_block
    )
    tl.store(KVCache + offsets_kvcache, kv)
    return
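
To make the addressing above concrete, here is a small, unoptimized PyTorch reference of what each (sequence, kv head) program instance does; the function name and loop structure are illustrative only and assume the tensor layouts documented in the wrapper below:

import torch

def copy_one_token_reference(k, k_cache, context_lengths, block_tables, block_size):
    # k: [bsz, num_kv_heads, head_dim] (new token, already squeezed)
    # k_cache: [num_blocks, num_kv_heads, head_dim, block_size]
    bsz, num_kv_heads, _ = k.shape
    for seq_idx in range(bsz):
        past_len = context_lengths[seq_idx].item()
        block_id = block_tables[seq_idx, past_len // block_size]  # physical block holding the new slot
        offset = past_len % block_size                            # slot index inside that block
        for head_idx in range(num_kv_heads):
            k_cache[block_id, head_idx, :, offset] = k[seq_idx, head_idx]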


# Used with blocked kv cache.
# Copy k or v to the blocked k/v cache during the decoding stage.
def copy_kv_to_blocked_cache(
    k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_dim], k or v during decoding stage
    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size], blocked k or v cache (for now, both caches share this shape)
    context_lengths: torch.Tensor,  # [bsz], past kv seq len (not incorporating the current kv of length 1)
    block_tables: torch.Tensor,  # [bsz, max_blocks_per_sequence]
):
    assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
    assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
    assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
    assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
    bsz, _, num_kv_heads, head_dim = k.shape
    assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
        f"Got incompatible batch size (number of seqs):\n"
        f"  Context lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
        f"batch size {bsz}"
    )

    # Modify if the shape of the kv cache is changed.
    block_size = k_cache.size(-1)
    # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
    k = k.squeeze(dim=1)

    num_warps = 8 if head_dim > 128 else 4

    grid = (bsz, num_kv_heads)
    _copy_to_kvcache_seqlen1_kernel[grid](
        k,
        k_cache,
        block_tables,
        context_lengths,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        block_size,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
    )
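
A minimal usage sketch of the wrapper (shapes and values chosen for illustration; it assumes a CUDA device, that copy_kv_to_blocked_cache is importable from colossalai.kernel.triton as this diff exports it, and that a slot for the new token has already been reserved in the block tables):

import torch
from colossalai.kernel.triton import copy_kv_to_blocked_cache

bsz, num_kv_heads, head_dim, block_size, max_blocks = 2, 16, 128, 16, 8
device = "cuda"

new_k = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=torch.float16, device=device)
k_cache = torch.zeros(bsz * max_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16, device=device)
context_lengths = torch.tensor([5, 21], dtype=torch.int32, device=device)  # past kv lengths
# toy block tables: sequence i owns physical blocks [i * max_blocks, (i + 1) * max_blocks)
block_tables = torch.arange(bsz * max_blocks, dtype=torch.int32, device=device).view(bsz, max_blocks)

copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
# sequence 0's new token lands in block 0 at offset 5; sequence 1's lands in block 9 at offset 5 (21 // 16 = 1, 21 % 16 = 5)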


@@ -100,3 +100,29 @@ def mock_alloc_block_table_and_kvcache(
            block_id += 1
    return block_tables


def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int):
    """Allocate 1 token on the block table for each sequence in the block tables.
    It won't change the provided context_lengths.
    """
    # consider max_block_id as the last physical block allocated
    # NOTE It assumes all the blocks preceding this block have been allocated
    max_block_id = torch.max(block_tables).item()
    # the index on each block table of the cache block that is to receive one more token
    alloc_local_block_indices = context_lengths // block_size
    # offsets of the token to be allocated on the target block (for each seq)
    alloc_block_offsets = context_lengths % block_size

    require_new_block = alloc_block_offsets == 0
    new_block_ids = torch.arange(
        max_block_id + 1,
        max_block_id + 1 + require_new_block.sum(),
        dtype=block_tables.dtype,
        device=block_tables.device,
    )

    if new_block_ids.numel():
        new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
        block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
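
To illustrate the helper with hypothetical numbers (not part of the test utilities): a sequence whose past length is an exact multiple of the block size gets a fresh physical block, while the others keep writing into their current block.

import torch

block_size = 4
# two sequences: seq 0 already owns blocks [0, 1], seq 1 owns block [2]; -1 marks unallocated entries
block_tables = torch.tensor([[0, 1, -1], [2, -1, -1]])
context_lengths = torch.tensor([8, 3])  # past kv lengths

mock_alloc_single_token(block_tables, context_lengths, block_size)
print(block_tables)
# tensor([[ 0,  1,  3],
#         [ 2, -1, -1]])
# seq 0 (8 % 4 == 0) gets new physical block 3 at local index 2; seq 1 still has room in block 2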


@@ -0,0 +1,168 @@
import pytest
import torch
from packaging import version

from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, mock_alloc_single_token

try:
    import triton  # noqa

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def prepare_data(
    bsz,
    num_kv_heads,
    head_dim,
    block_size,
    max_num_blocks_per_seq,
    same_context_len,
    max_seq_len,
    device,
    dtype=torch.float16,
):
    if same_context_len:
        # context_lengths in this test records the previous kv seq len
        # (not incorporating the current input whose seq len is 1)
        context_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
    else:
        context_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(context_lengths).item()

    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)

    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache(
        k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    block_tables = block_tables.to(device=device)

    new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
    # mock allocating blocks for the new k/v and update block tables
    mock_alloc_single_token(block_tables, context_lengths, block_size)

    return new_k, k_cache, context_lengths, block_tables


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_copy_kv_to_caches(
    bsz: int,
    block_size: int,
    max_num_blocks_per_seq: int,
    num_kv_heads: int,
    same_context_len: bool,
):
    torch.manual_seed(123)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    head_dim = 128
    max_seq_len = block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()

    new_k, k_cache, context_lengths, block_tables = prepare_data(
        bsz,
        num_kv_heads,
        head_dim,
        block_size,
        max_num_blocks_per_seq,
        same_context_len,
        max_seq_len,
        device=device,
        dtype=dtype,
    )
    copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)

    for seq_i in range(bsz):
        ki = new_k[seq_i]
        ki = ki.squeeze()
        context_len_i = context_lengths[seq_i]
        target_block_id = block_tables[seq_i, context_len_i // block_size]
        offsets_in_block = context_len_i % block_size
        target = k_cache[target_block_id, :, :, offsets_in_block]
        orig = new_k[seq_i].squeeze(dim=0)
        assert torch.equal(orig, target)


BATCH = 4
configs = [
    triton.testing.Benchmark(
        x_names=["PAST_KVLEN"],
        x_vals=[2**i - 1 for i in range(8, 13)],
        line_arg="provider",
        line_vals=["torch_copy_func", "triton_copy_func"],
        line_names=["torch_copy_func", "triton_copy_func"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
        args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
    )
]


@triton.testing.perf_report(configs)
def benchmark_kvcache_copy(
    provider: str,
    bsz: int,
    block_size: int,
    max_seq_len: int,
    PAST_KVLEN: int,  # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
    num_kv_heads: int,
    same_context_len: bool,
):
    warmup = 10
    rep = 100

    head_dim = 128
    dtype = torch.float16
    device = get_current_device()

    assert PAST_KVLEN < max_seq_len, "Assigned past kv length must be smaller than the maximum seq len"

    new_k, k_cache, context_lengths, block_tables = prepare_data(
        bsz,
        num_kv_heads,
        head_dim,
        block_size,
        max_seq_len // block_size,
        same_context_len,
        PAST_KVLEN,
        device=device,
        dtype=dtype,
    )

    if provider == "torch_copy_func":
        fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
    elif provider == "triton_copy_func":
        fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
    else:
        raise ValueError("Undefined provider.")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    return ms


if __name__ == "__main__":
    test_copy_kv_to_caches(4, 32, 8, 16, False)
    # benchmark_kvcache_copy.run(save_path=".")
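
To produce the timing comparison instead of the correctness test, one would swap which call is commented out, for example (a sketch assuming the usual triton.testing perf_report run options):

if __name__ == "__main__":
    # test_copy_kv_to_caches(4, 32, 8, 16, False)
    benchmark_kvcache_copy.run(save_path=".", print_data=True)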