diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 5fa1e7161..876fed456 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import ( from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import ( context_attention_unpadded, decoding_fused_rotary_embedding, @@ -22,6 +23,8 @@ from colossalai.kernel.triton import ( ) from colossalai.logging import get_dist_logger +inference_ops = InferenceOpsLoader().load() + logger = get_dist_logger(__name__) try: @@ -74,6 +77,12 @@ def llama_model_forward( sequence_lengths = batch.get_sequence_lengths() batch_size = batch.current_batch_size kv_seq_len = sequence_lengths.max().item() + use_cuda_kernel = True + # NOTE: After testing, the performance of this configuration is relatively good. With updates + # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's + # selection should be conducted. + if batch_size >= 32 and kv_seq_len > 512: + use_cuda_kernel = False hidden_states = self.embed_tokens(input_ids) @@ -107,6 +116,7 @@ def llama_model_forward( output_tensor=output_tensor, norm_output=norm_output, sm_scale=sm_scale, + use_cuda_kernel=use_cuda_kernel, ) if batch.is_prompts: @@ -134,6 +144,7 @@ def llama_decoder_layer_forward( output_tensor: torch.Tensor = None, norm_output: torch.Tensor = None, sm_scale: int = None, + use_cuda_kernel: bool = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """This function will replace the forward function of LlamaDecoderLayer. @@ -153,6 +164,7 @@ def llama_decoder_layer_forward( output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. + use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. """ hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual) @@ -169,6 +181,7 @@ def llama_decoder_layer_forward( fd_inter_tensor=fd_inter_tensor, output_tensor=output_tensor, sm_scale=sm_scale, + use_cuda_kernel=use_cuda_kernel, ) # Fully Connected @@ -252,6 +265,7 @@ class NopadLlamaAttention(LlamaAttention): fd_inter_tensor: FDIntermTensors = None, output_tensor: torch.Tensor = None, sm_scale: int = None, + use_cuda_kernel: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Args: @@ -268,6 +282,7 @@ class NopadLlamaAttention(LlamaAttention): storing intermediate values in flash-decoding. Defaults to None. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. + use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. 
""" if self.num_heads != self.num_key_value_heads: @@ -283,7 +298,6 @@ class NopadLlamaAttention(LlamaAttention): ) block_size = k_cache.size(-2) - if is_prompts: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) attn_output = context_attention_unpadded( @@ -300,17 +314,23 @@ class NopadLlamaAttention(LlamaAttention): sm_scale=sm_scale, ) else: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) + if use_cuda_kernel: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + inference_ops.decode_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables + ) + else: + decoding_fused_rotary_embedding( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + block_tables, + sequence_lengths, + ) attn_output = flash_decoding_attention( q=query_states, k_cache=k_cache, diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py index 148c3e3fc..f13e6223f 100644 --- a/colossalai/kernel/kernel_loader.py +++ b/colossalai/kernel/kernel_loader.py @@ -8,6 +8,7 @@ from .extensions import ( FlashAttentionNpuExtension, FlashAttentionXformersCudaExtension, FusedOptimizerCudaExtension, + InferenceOpsCudaExtension, LayerNormCudaExtension, MoeCudaExtension, ScaledMaskedSoftmaxCudaExtension, @@ -21,6 +22,7 @@ __all__ = [ "LayerNormLoader", "MoeLoader", "FusedOptimizerLoader", + "InferenceOpsLoader", "ScaledMaskedSoftmaxLoader", "ScaledUpperTriangleMaskedSoftmaxLoader", ] @@ -97,6 +99,10 @@ class FusedOptimizerLoader(KernelLoader): REGISTRY = [FusedOptimizerCudaExtension] +class InferenceOpsLoader(KernelLoader): + REGISTRY = [InferenceOpsCudaExtension] + + class ScaledMaskedSoftmaxLoader(KernelLoader): REGISTRY = [ScaledMaskedSoftmaxCudaExtension] diff --git a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py new file mode 100644 index 000000000..de334e1f7 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py @@ -0,0 +1,80 @@ +import torch + +from colossalai.inference.modeling.layers.attention import copy_to_cache +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import copy_kv_to_blocked_cache +from colossalai.utils import get_current_device +from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data + +try: + import triton # noqa +except ImportError: + print("please install triton from https://github.com/openai/triton") + +inference_ops = InferenceOpsLoader().load() + +HEAD_DIM = 4 +BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 +configs = [ + triton.testing.Benchmark( + x_names=["KV_SEQ_LEN"], + x_vals=[2**i for i in range(8, 13)], + line_arg="provider", + line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"], + line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"], + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}", + args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_kvcache_copy( + provider: str, + bsz: int, + block_size: int, + max_seq_len: int, + KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal 
context lens) + num_kv_heads: int, + same_context_len: bool, +): + dtype = torch.float32 + device = get_current_device() + + assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len" + + new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data( + bsz, + num_kv_heads, + HEAD_DIM, + block_size, + max_seq_len // block_size, + same_context_len, + KV_SEQ_LEN, + device=device, + dtype=dtype, + ) + + quantiles = [0.5, 0.2, 0.8] + # TODO copy_to_cache needs to support copying both k and v at the same time in the future. + if provider == "torch_copy_func": + fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding") + elif provider == "triton_copy_func": + fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) + elif provider == "cuda_copy_func": + new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k + new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v + fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + return ms, min_ms, max_ms + + +if __name__ == "__main__": + benchmark_kvcache_copy.run(save_path=".", print_data=True) diff --git a/extensions/__init__.py b/extensions/__init__.py index 9343cadda..c3da1552a 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -4,6 +4,7 @@ from .flash_attention import ( FlashAttentionNpuExtension, FlashAttentionXformersCudaExtension, ) +from .inference import InferenceOpsCudaExtension from .layernorm import LayerNormCudaExtension from .moe import MoeCudaExtension from .optimizer import FusedOptimizerCudaExtension @@ -15,6 +16,7 @@ ALL_EXTENSIONS = [ LayerNormCudaExtension, MoeCudaExtension, FusedOptimizerCudaExtension, + InferenceOpsCudaExtension, ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension, FlashAttentionDaoCudaExtension, @@ -28,6 +30,7 @@ __all__ = [ "LayerNormCudaExtension", "MoeCudaExtension", "FusedOptimizerCudaExtension", + "InferenceOpsCudaExtension", "ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension", "FlashAttentionDaoCudaExtension", diff --git a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp new file mode 100644 index 000000000..ae410c14f --- /dev/null +++ b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp @@ -0,0 +1,15 @@ +#include + +void decode_kv_cache_memcpy( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables); // [batch_size, max_seq_len] + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the decode stage."); +} diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000..86db90c8b --- /dev/null +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -0,0 +1,90 @@ +#include +#include +#include + +#include "type_shim.h" + +template +__global__ void 
decode_kv_cache_memcpy_kernel( + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int num_heads, + const int head_size, + const int block_size, + const int key_stride, + const int value_stride, + const int block_table_stride +) +{ + const int seq_id = blockIdx.x; + const int seq_len = sequence_lengths[seq_id] - 1; + const int seq_id_in_block_table = seq_len / block_size; + const int block_offset = seq_len % block_size; + const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table]; + const int hidden_size = num_heads * head_size; + + if ( block_id < 0 ) { + return ; + } + + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + const int head_id = i / head_size; + const int head_offset = i % head_size; + const int key_src_id = seq_id * key_stride + i; + const int value_src_id = seq_id * value_stride + i; + const int target_src_id = block_id * hidden_size * block_size + + head_id * block_size * head_size + + block_offset * head_size + head_offset; + + key_cache[target_src_id] = key[key_src_id]; + value_cache[target_src_id] = value[value_src_id]; + } + +} + +void decode_kv_cache_memcpy( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(2); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "decode_kv_cache_memcpy", + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + num_heads, + head_size, + block_size, + key_stride, + value_stride, + block_table_stride + );) + + AT_CUDA_CHECK(cudaGetLastError()); + +} diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/cuda/type_shim.h index 03ccc0263..511631935 100644 --- a/extensions/csrc/cuda/type_shim.h +++ b/extensions/csrc/cuda/type_shim.h @@ -24,6 +24,27 @@ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ switch (TYPEIN) { \ case at::ScalarType::Float: { \ diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py index b5e8a285b..842cd9713 100644 --- a/extensions/cuda_extension.py +++ b/extensions/cuda_extension.py @@ -1,7 +1,10 @@ import os +import time from abc import abstractmethod +from pathlib import Path from typing import List +from .base_extension import _Extension from .cpp_extension import _CppExtension from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list diff --git a/extensions/inference/__init__.py b/extensions/inference/__init__.py new file mode 100644 index 000000000..c5ea424fa --- /dev/null +++ b/extensions/inference/__init__.py @@ -0,0 +1,3 @@ +from .inference_ops_cuda import InferenceOpsCudaExtension + +__all__ = ["InferenceOpsCudaExtension"] diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py new file mode 100644 index 000000000..12bec6fab --- /dev/null +++ b/extensions/inference/inference_ops_cuda.py @@ -0,0 +1,30 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class InferenceOpsCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="inference_ops_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/colossal_inference_C_frontend.cpp", + "cuda/decode_kv_cache_memcpy_kernel.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/tests/test_infer/test_ops/__init__.py b/tests/test_infer/test_ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_infer/test_ops/cuda/__init__.py b/tests/test_infer/test_ops/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py new file mode 100644 index 000000000..d5259a596 --- /dev/null +++ b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py @@ -0,0 +1,65 @@ +import pytest +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device +from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data + +inference_ops = InferenceOpsLoader().load() + +HEAD_DIM = 4 + + +@pytest.mark.parametrize("bsz", [4, 7, 32]) +@pytest.mark.parametrize("block_size", [16, 32, 64]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) +@pytest.mark.parametrize("num_kv_heads", [16]) +@pytest.mark.parametrize("same_context_len", [True, False]) +def test_copy_kv_to_caches( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + max_seq_len = block_size * max_num_blocks_per_seq + dtype = torch.float32 + device = get_current_device() + + new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data( + bsz, + num_kv_heads, + HEAD_DIM, + block_size, + max_num_blocks_per_seq, + same_context_len, + max_seq_len, + device=device, + 
dtype=dtype, + ) + + new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k + new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v + inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables) + + past_kv_seq_len = kv_seq_lengths - 1 + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size + k_target = k_cache[target_block_ids, :, offsets_in_block, :] + k_source = new_k.squeeze() + v_target = v_cache[target_block_ids, :, offsets_in_block, :] + v_source = new_v.squeeze() + + assert k_target.shape == k_source.shape + assert torch.equal(k_target, k_source) + assert v_target.shape == v_source.shape + assert torch.equal(v_target, v_source) + + +if __name__ == "__main__": + test_copy_kv_to_caches(4, 32, 8, 16, True) diff --git a/tests/test_infer/test_ops/triton/__init__.py b/tests/test_infer/test_ops/triton/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_ops/triton/test_kvcache_copy.py index 53475270e..b3fdd4b88 100644 --- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer/test_ops/triton/test_kvcache_copy.py @@ -2,7 +2,6 @@ import pytest import torch from packaging import version -from colossalai.inference.modeling.layers.attention import copy_to_cache from colossalai.kernel.triton import copy_kv_to_blocked_cache from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token @@ -108,69 +107,7 @@ def test_copy_kv_to_caches( assert torch.equal(k_target, k_source) assert v_target.shape == v_source.shape assert torch.equal(v_target, v_source) - # target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :] - # assert target_torch.shape == source.shape - # assert torch.equal(target_torch, source) - - -BATCH = 16 -BLOCK_SIZE = 32 -SAME_LEN = True -WARM_UPS = 10 -REPS = 100 -configs = [ - triton.testing.Benchmark( - x_names=["KV_SEQ_LEN"], - x_vals=[2**i for i in range(8, 13)], - line_arg="provider", - line_vals=["torch_copy_func", "triton_copy_func"], - line_names=["torch_copy_func", "triton_copy_func"], - styles=[("red", "-"), ("blue", "-")], - ylabel="ms", - plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}", - args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True}, - ) -] - - -@triton.testing.perf_report(configs) -def benchmark_kvcache_copy( - provider: str, - bsz: int, - block_size: int, - max_seq_len: int, - KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens) - num_kv_heads: int, - same_context_len: bool, -): - dtype = torch.float16 - device = get_current_device() - - assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len" - - new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data( - bsz, - num_kv_heads, - HEAD_DIM, - block_size, - max_seq_len // block_size, - same_context_len, - KV_SEQ_LEN, - device=device, - dtype=dtype, - ) - - quantiles = [0.5, 0.2, 0.8] - # TODO copy_to_cache needs to support copying both k and v at the same time in the future. 
- if provider == "torch_copy_func": - fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding") - if provider == "triton_copy_func": - fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) - - ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) - return ms, min_ms, max_ms if __name__ == "__main__": test_copy_kv_to_caches(4, 32, 8, 16, True) - # benchmark_kvcache_copy.run(save_path=".", print_data=True)