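"""Tests for the Triton kernels that copy newly generated k/v tokens into blocked (paged) KV caches.

Exercises the multi-token k copy (`copy_k_to_blocked_cache`) for n_tokens in {1, 5},
and the fused k+v copy (`copy_kv_to_blocked_cache`) for the single-token case,
under both the original and the new k-cache memory layouts.
"""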
import pytest
import torch
from packaging import version

from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
    generate_caches_and_block_tables_v2,
    generate_caches_and_block_tables_v3,
    mock_alloc_single_token,
)

try:
    import triton  # noqa

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

# torch.version.cuda is None on CPU-only builds, so guard before version-parsing it.
TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")

HEAD_DIM = 32


def prepare_data(
    bsz,
    num_kv_heads,
    head_dim,
    block_size,
    max_num_blocks_per_seq,
    same_context_len,
    max_seq_len,
    n=1,
    device="cuda",
    dtype=torch.float16,
    use_new_kcache_layout=False,
):
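    """Build unpadded past k/v, blocked caches, block tables, and `n` new tokens per sequence.

    Past context lengths are either identical (`max_seq_len - n`) or sampled per
    sequence. Blocks for the `n` new tokens are mock-allocated up front, so the
    returned sequence lengths already include the tokens about to be written.
    """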
    assert max_seq_len > n, "max_seq_len must be greater than n"

    past_kv_seq_lengths = (
        torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
        if same_context_len
        else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
    )
    num_tokens = torch.sum(past_kv_seq_lengths).item()

    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

    if use_new_kcache_layout:
        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
        )
    else:
        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
        )
    block_tables = block_tables.to(device=device)

    new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
    new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
    # Mock-allocate blocks for the n new tokens and advance the sequence lengths,
    # updating the block tables as we go.
    for _ in range(n):
        mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
        past_kv_seq_lengths += 1

    return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables
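

# NOTE: cache layouts as inferred from the indexing in this test (the split
# factor `x` is a property of the layout, not asserted here):
#   v2 ("old") k/v cache: [num_blocks, num_kv_heads, block_size, head_dim]
#   v3 ("new") k cache:   [num_blocks, num_kv_heads, head_dim // x, block_size, x]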


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [16])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("n_tokens", [1, 5])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_copy_kv_to_caches(
    bsz: int,
    block_size: int,
    max_num_blocks_per_seq: int,
    num_kv_heads: int,
    same_context_len: bool,
    n_tokens: int,
    use_new_kcache_layout: bool,
):
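    """Write new tokens into the blocked caches, then read them back and check for an exact match.

    `copy_k_to_blocked_cache` is checked for every `n_tokens`; the fused
    `copy_kv_to_blocked_cache` path is exercised only when `n_tokens == 1`.
    """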
    torch.manual_seed(123)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    max_seq_len = block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()

    new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
        bsz,
        num_kv_heads,
        HEAD_DIM,
        block_size,
        max_num_blocks_per_seq,
        same_context_len,
        max_seq_len,
        n_tokens,
        device=device,
        dtype=dtype,
        use_new_kcache_layout=use_new_kcache_layout,
    )
    k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
    v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
    k_cache_copy = k_cache.detach().clone()
    past_kv_seq_lengths = kv_seq_lengths - n_tokens
    target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size]
    offsets_in_block = past_kv_seq_lengths % block_size

    # Copy the n_tokens new k tokens per sequence into the blocked k cache
    copy_k_to_blocked_cache(
        new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
    )
    # Gather the newly written k slices back out of the cache and compare them
    # with the source tensor. The per-sequence loop mainly handles n_tokens > 1,
    # where the new tokens may straddle a block boundary.
    k_target = []
    for i in range(bsz):
        block_table = block_tables[i]
        curr_kv_len = past_kv_seq_lengths[i].item()
        offset = offsets_in_block[i].item()
        tokens_left = n_tokens
        while tokens_left > 0:
            tokens_to_fill = min(block_size - offset, tokens_left)
            curr_block_id = block_table[curr_kv_len // block_size]
            if use_new_kcache_layout:
                k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
            else:
                k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
            curr_kv_len += tokens_to_fill
            tokens_left -= tokens_to_fill
            offset = 0
    if use_new_kcache_layout:
        k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
        k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
    else:
        k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous()  # [bsz * n, num_kv_heads, head_dim]
    assert k_target.shape == k_source.shape
    assert torch.equal(k_target, k_source)

    if n_tokens == 1:
        # Re-run from the untouched k-cache copy, this time with the kernel that
        # writes k and v together (only exercised when n_tokens == 1).
        k_cache = k_cache_copy
        copy_kv_to_blocked_cache(
            new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
        )

        if use_new_kcache_layout:
            k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
            k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
        else:
            k_target = k_cache[target_block_ids, :, offsets_in_block, :]
        assert k_target.shape == k_source.shape
        assert torch.equal(k_target, k_source)
        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
        assert v_target.shape == v_source.shape
        assert torch.equal(v_target, v_source)


if __name__ == "__main__":
    # use_new_kcache_layout has no default value, so it must be passed explicitly.
    test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1, use_new_kcache_layout=False)