[Inference]Add CUDA KVCache Kernel (#5406)

* add cuda KVCache kernel

* annotation benchmark_kvcache_copy

* add use cuda

* fix import path

* move benchmark scripts to example/

* rm benchmark codes in test_kv_cache_memcpy.py

* rm redundancy codes

* rm redundancy codes

* pr was modified according to the review
pull/5408/head
yuehuayingxueluo 2024-02-28 14:36:50 +08:00 committed by GitHub
parent 19061188c3
commit 600881a8ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 348 additions and 75 deletions

View File

@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import (
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
decoding_fused_rotary_embedding,
@ -22,6 +23,8 @@ from colossalai.kernel.triton import (
)
from colossalai.logging import get_dist_logger
inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
try:
@ -74,6 +77,12 @@ def llama_model_forward(
sequence_lengths = batch.get_sequence_lengths()
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
use_cuda_kernel = True
# NOTE: After testing, the performance of this configuration is relatively good. With updates
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
# selection should be conducted.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
hidden_states = self.embed_tokens(input_ids)
@ -107,6 +116,7 @@ def llama_model_forward(
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
)
if batch.is_prompts:
@ -134,6 +144,7 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
@ -153,6 +164,7 @@ def llama_decoder_layer_forward(
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
"""
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
@ -169,6 +181,7 @@ def llama_decoder_layer_forward(
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
)
# Fully Connected
@ -252,6 +265,7 @@ class NopadLlamaAttention(LlamaAttention):
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
@ -268,6 +282,7 @@ class NopadLlamaAttention(LlamaAttention):
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
"""
if self.num_heads != self.num_key_value_heads:
@ -283,7 +298,6 @@ class NopadLlamaAttention(LlamaAttention):
)
block_size = k_cache.size(-2)
if is_prompts:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
@ -300,17 +314,23 @@ class NopadLlamaAttention(LlamaAttention):
sm_scale=sm_scale,
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
if use_cuda_kernel:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,

View File

@ -8,6 +8,7 @@ from .extensions import (
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
InferenceOpsCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
@ -21,6 +22,7 @@ __all__ = [
"LayerNormLoader",
"MoeLoader",
"FusedOptimizerLoader",
"InferenceOpsLoader",
"ScaledMaskedSoftmaxLoader",
"ScaledUpperTriangleMaskedSoftmaxLoader",
]
@ -97,6 +99,10 @@ class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]
class InferenceOpsLoader(KernelLoader):
REGISTRY = [InferenceOpsCudaExtension]
class ScaledMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]

View File

@ -0,0 +1,80 @@
import torch
from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
try:
import triton # noqa
except ImportError:
print("please install triton from https://github.com/openai/triton")
inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 4
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
triton.testing.Benchmark(
x_names=["KV_SEQ_LEN"],
x_vals=[2**i for i in range(8, 13)],
line_arg="provider",
line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
)
]
@triton.testing.perf_report(configs)
def benchmark_kvcache_copy(
provider: str,
bsz: int,
block_size: int,
max_seq_len: int,
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
num_kv_heads: int,
same_context_len: bool,
):
dtype = torch.float32
device = get_current_device()
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
block_size,
max_seq_len // block_size,
same_context_len,
KV_SEQ_LEN,
device=device,
dtype=dtype,
)
quantiles = [0.5, 0.2, 0.8]
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
if provider == "torch_copy_func":
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
elif provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
elif provider == "cuda_copy_func":
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms
if __name__ == "__main__":
benchmark_kvcache_copy.run(save_path=".", print_data=True)

View File

@ -4,6 +4,7 @@ from .flash_attention import (
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
)
from .inference import InferenceOpsCudaExtension
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
@ -15,6 +16,7 @@ ALL_EXTENSIONS = [
LayerNormCudaExtension,
MoeCudaExtension,
FusedOptimizerCudaExtension,
InferenceOpsCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
@ -28,6 +30,7 @@ __all__ = [
"LayerNormCudaExtension",
"MoeCudaExtension",
"FusedOptimizerCudaExtension",
"InferenceOpsCudaExtension",
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",

View File

@ -0,0 +1,15 @@
#include <torch/extension.h>
void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
}

View File

@ -0,0 +1,90 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>
#include "type_shim.h"
template<typename scalar_t>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int num_heads,
const int head_size,
const int block_size,
const int key_stride,
const int value_stride,
const int block_table_stride
)
{
const int seq_id = blockIdx.x;
const int seq_len = sequence_lengths[seq_id] - 1;
const int seq_id_in_block_table = seq_len / block_size;
const int block_offset = seq_len % block_size;
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
const int hidden_size = num_heads * head_size;
if ( block_id < 0 ) {
return ;
}
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
const int head_id = i / head_size;
const int head_offset = i % head_size;
const int key_src_id = seq_id * key_stride + i;
const int value_src_id = seq_id * value_stride + i;
const int target_src_id = block_id * hidden_size * block_size
+ head_id * block_size * head_size
+ block_offset * head_size + head_offset;
key_cache[target_src_id] = key[key_src_id];
value_cache[target_src_id] = value[value_src_id];
}
}
void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(2);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"decode_kv_cache_memcpy",
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
num_heads,
head_size,
block_size,
key_stride,
value_stride,
block_table_stride
);)
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -24,6 +24,27 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \

View File

@ -1,7 +1,10 @@
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List
from .base_extension import _Extension
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list

View File

@ -0,0 +1,3 @@
from .inference_ops_cuda import InferenceOpsCudaExtension
__all__ = ["InferenceOpsCudaExtension"]

View File

@ -0,0 +1,30 @@
from ..cuda_extension import _CudaExtension
from ..utils import get_cuda_cc_flag
class InferenceOpsCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="inference_ops_cuda")
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"cuda/colossal_inference_C_frontend.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
]
]
return ret
def include_dirs(self):
ret = [self.get_cuda_home_include()]
return ret
def cxx_flags(self):
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
return ["-O3"] + version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ["-lineinfo"]
extra_cuda_flags.extend(get_cuda_cc_flag())
return ["-O3", "--use_fast_math"] + extra_cuda_flags

View File

View File

@ -0,0 +1,65 @@
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 4
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float32
device = get_current_device()
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
block_size,
max_num_blocks_per_seq,
same_context_len,
max_seq_len,
device=device,
dtype=dtype,
)
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
k_source = new_k.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
v_source = new_v.squeeze()
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if __name__ == "__main__":
test_copy_kv_to_caches(4, 32, 8, 16, True)

View File

@ -2,7 +2,6 @@ import pytest
import torch
from packaging import version
from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
@ -108,69 +107,7 @@ def test_copy_kv_to_caches(
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
# assert target_torch.shape == source.shape
# assert torch.equal(target_torch, source)
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
triton.testing.Benchmark(
x_names=["KV_SEQ_LEN"],
x_vals=[2**i for i in range(8, 13)],
line_arg="provider",
line_vals=["torch_copy_func", "triton_copy_func"],
line_names=["torch_copy_func", "triton_copy_func"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
)
]
@triton.testing.perf_report(configs)
def benchmark_kvcache_copy(
provider: str,
bsz: int,
block_size: int,
max_seq_len: int,
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
num_kv_heads: int,
same_context_len: bool,
):
dtype = torch.float16
device = get_current_device()
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
block_size,
max_seq_len // block_size,
same_context_len,
KV_SEQ_LEN,
device=device,
dtype=dtype,
)
quantiles = [0.5, 0.2, 0.8]
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
if provider == "torch_copy_func":
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
if provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms
if __name__ == "__main__":
test_copy_kv_to_caches(4, 32, 8, 16, True)
# benchmark_kvcache_copy.run(save_path=".", print_data=True)