From 5cd75ce4c7edc95bacd8ec5fc04b8add339e8331 Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Tue, 30 Apr 2024 15:52:23 +0800 Subject: [PATCH] =?UTF-8?q?[Inference/Kernel]=20refactor=20kvcache=20manag?= =?UTF-8?q?er=20and=20rotary=5Fembedding=20and=20kvcache=5Fmemcpy=20oper?= =?UTF-8?q?=E2=80=A6=20(#5663)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattention --- .../inference/kv_cache/kvcache_manager.py | 23 ++- .../modeling/models/nopadding_baichuan.py | 46 ++++-- .../modeling/models/nopadding_llama.py | 67 ++++---- .../benchmark_flash_decoding_attention.py | 6 +- .../benchmark_fused_rotary_embdding_unpad.py | 18 ++- .../benchmark_kv_cache_memcopy.py | 4 + .../cuda/context_kv_cache_memcpy_kernel.cu | 46 ++++-- .../cuda/decode_kv_cache_memcpy_kernel.cu | 39 +++-- .../cuda/flash_decoding_attention_kernel.cu | 15 +- .../cuda/fused_rotary_emb_and_cache_kernel.cu | 147 +++++++----------- extensions/pybind/inference/inference.cpp | 28 ++-- .../cuda/test_flash_decoding_attention.py | 49 +++++- .../test_ops/cuda/test_kv_cache_memcpy.py | 100 ++++++++---- .../cuda/test_rotary_embdding_unpad.py | 15 +- 14 files changed, 368 insertions(+), 235 deletions(-) diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 8b9605a52..50546271e 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -90,9 +90,18 @@ class KVCacheManager: self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Physical cache allocation - alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) - self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") - self._kv_caches = self._init_device_caches(alloc_shape) + if config.use_cuda_kernel: + x = 16 // torch.tensor([], dtype=config.dtype).element_size() + kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x) + valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + self.logger.info( + f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks." + ) + self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape) + else: + alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape) self.total_physical_cache_size_in_bytes = ( self.elem_size_in_bytes * self.num_layers @@ -479,7 +488,9 @@ class KVCacheManager: blocks.append(cache_block) return blocks - def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]: + def _init_device_caches( + self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...] + ) -> Tuple[torch.Tensor, torch.Tensor]: """Initialize the physical cache on the device. For each layer of the model, we allocate two tensors for key and value respectively, @@ -488,6 +499,6 @@ class KVCacheManager: k_cache: List[torch.Tensor] = [] v_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): - k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)) - v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)) + k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device)) + v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device)) return k_cache, v_cache diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index 441d941e1..ca8a0e696 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -310,6 +310,7 @@ class NopadBaichuanAttention(ParallelModule): alibi_slopes=self.alibi_slopes, max_seq_len=kv_seq_len, sm_scale=sm_scale, + use_new_kcache_layout=use_cuda_kernel, ) else: q_len = tokens_to_verify + 1 if is_verifier else 1 @@ -332,6 +333,21 @@ class NopadBaichuanAttention(ParallelModule): inference_ops.decode_kv_cache_memcpy( key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables ) + inference_ops.flash_decoding_attention( + output_tensor, + query_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + block_size, + kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.mid_output_lse, + self.alibi_slopes, + sm_scale, + ) + attn_output = output_tensor else: if not is_verifier and not self.use_alibi_attn: decoding_fused_rotary_embedding( @@ -355,21 +371,21 @@ class NopadBaichuanAttention(ParallelModule): value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len ) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - alibi_slopes=self.alibi_slopes, - sm_scale=sm_scale, - q_len=q_len, - ) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + alibi_slopes=self.alibi_slopes, + sm_scale=sm_scale, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 8249eafcf..557ca0d12 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -98,15 +98,8 @@ def llama_model_forward( """ block_tables = inputmetadata.block_tables sequence_lengths = inputmetadata.sequence_lengths - batch_size = inputmetadata.batch_size kv_seq_len = inputmetadata.kv_seq_len - # NOTE: After testing, the performance of this configuration is relatively good. With updates - # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's - # selection should be conducted. - if batch_size >= 32 and kv_seq_len > 512: - use_cuda_kernel = False - # NOTE (yuanheng-zhao): fow now, only triton kernels support verification process # during speculative-decoding (`q_len > 1`) # We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled @@ -575,6 +568,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): output=output_tensor, max_seq_len=kv_seq_len, sm_scale=sm_scale, + use_new_kcache_layout=use_cuda_kernel, ) else: q_len = tokens_to_verify + 1 if is_verifier else 1 @@ -592,20 +586,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): block_tables, high_precision, ) - # inference_ops.flash_decoding_attention( - # output_tensor, - # query_states, - # k_cache, - # v_cache, - # sequence_lengths, - # block_tables, - # block_size, - # kv_seq_len, - # fd_inter_tensor.mid_output, - # fd_inter_tensor.mid_output_lse, - # sm_scale, - # ) - # attn_output = output_tensor + inference_ops.flash_decoding_attention( + output_tensor, + query_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + block_size, + kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.mid_output_lse, + None, + sm_scale, + ) + attn_output = output_tensor else: if is_verifier: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) @@ -627,21 +622,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): block_tables, sequence_lengths, ) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, - kv_group_num=self.num_key_value_groups, - q_len=q_len, - ) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=sm_scale, + kv_group_num=self.num_key_value_groups, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py index 1a18ffa2e..35eae69b6 100644 --- a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py +++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py @@ -20,7 +20,7 @@ inference_ops = InferenceOpsLoader().load() configs = [ triton.testing.Benchmark( x_names=["MAX_NUM_BLOCKS_PER_SEQ"], - x_vals=[2**i for i in range(3, 8)], + x_vals=[2**i for i in range(2, 8)], line_arg="provider", line_vals=[ "vllm_paged_decoding_attention", @@ -113,6 +113,8 @@ def benchmark_flash_decoding_attention( kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) sm_scale = 1.0 / (HEAD_SIZE**0.5) + alibi_slopes = None + kv_scale = 1.0 mid_output = torch.empty( size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device @@ -136,6 +138,7 @@ def benchmark_flash_decoding_attention( max_seq_len_across_batch, alibi_slopes, "auto", + kv_scale, ) elif provider == "triton_flash_decoding_attention": fn = lambda: flash_decoding_attention( @@ -164,6 +167,7 @@ def benchmark_flash_decoding_attention( max_seq_len_across_batch, mid_output, mid_output_lse, + alibi_slopes, sm_scale, ) else: diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py index f11630dff..9c9fdcebd 100644 --- a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py @@ -2,7 +2,11 @@ import torch from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding -from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token +from tests.test_infer.test_ops.triton.kernel_utils import ( + mock_alloc_block_table_and_kvcache_v2, + mock_alloc_block_table_and_kvcache_v3, + mock_alloc_single_token, +) inference_ops = InferenceOpsLoader().load() @@ -68,11 +72,17 @@ def benchmark_rotary_emb( cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") + x = 16 // torch.tensor([], dtype=dtype).element_size() + new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda") past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") block_tables = mock_alloc_block_table_and_kvcache_v2( k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size ) + _ = mock_alloc_block_table_and_kvcache_v3( + k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda") new_q = torch.randn_like(new_k) new_v = torch.randn_like(new_k) @@ -94,12 +104,12 @@ def benchmark_rotary_emb( ) elif provider == "no_fused_cuda_rotary_emb_func": fn = lambda: [ - inference_ops.rotary_embedding(new_q, new_k, cos, sin), - inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables), + inference_ops.rotary_embedding(new_q, new_k, cos, sin, True), + inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables), ] elif provider == "fused_cuda_rotary_emb_func": fn = lambda: inference_ops.rotary_embedding_and_cache_copy( - new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables + new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True ) else: raise ValueError("Undefined provider") diff --git a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py index de334e1f7..8121eba59 100644 --- a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py +++ b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py @@ -4,6 +4,7 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import copy_kv_to_blocked_cache from colossalai.utils import get_current_device +from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data try: @@ -68,6 +69,9 @@ def benchmark_kvcache_copy( elif provider == "triton_copy_func": fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) elif provider == "cuda_copy_func": + _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout( + bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype + ) new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) diff --git a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu index 6e849b074..473324f45 100644 --- a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu @@ -24,14 +24,15 @@ __global__ void context_kv_cache_memcpy_kernel( const int batch_size, const int block_table_stride, const int64_t key_stride, - const int64_t value_stride + const int64_t value_stride, + const int x ) { const int seq_token_id = blockIdx.x; const int seq_id = blockIdx.y; const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size]; - if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) { + if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) { return ; } @@ -40,23 +41,33 @@ __global__ void context_kv_cache_memcpy_kernel( const int total_token_id = cu_seqlens[seq_id] + seq_token_id; int head_id; int head_offset; + int x_id; + int x_offset; int64_t key_src_id; int64_t value_src_id; - int64_t target_id; + int64_t target_key_id; + int64_t target_value_id; int i = threadIdx.x * VecSize; for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { head_id = i / head_dim; head_offset = i % head_dim; + x_id = head_offset / x; + x_offset = head_offset % x; key_src_id = total_token_id * key_stride + i; value_src_id = total_token_id * value_stride + i; - target_id = block_id * hidden_size * block_size + target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + target_value_id = block_id * hidden_size * block_size + head_id * block_size * head_dim + block_offset * head_dim + head_offset; - copy(key + key_src_id, key_cache + target_id); - copy(value + value_src_id, value_cache + target_id); + copy(key + key_src_id, key_cache + target_key_id); + copy(value + value_src_id, value_cache + target_value_id); } // tail process @@ -64,14 +75,21 @@ __global__ void context_kv_cache_memcpy_kernel( for (; i < hidden_size; ++i ) { head_id = i / head_dim; head_offset = i % head_dim; + x_id = head_offset / x; + x_offset = head_offset % x; key_src_id = total_token_id * key_stride + i; value_src_id = total_token_id * value_stride + i; - target_id = block_id * hidden_size * block_size + target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + target_value_id = block_id * hidden_size * block_size + head_id * block_size * head_dim + block_offset * head_dim + head_offset; - key_cache[target_id] = CastFunctor()(key[key_src_id]); - value_cache[target_id] = CastFunctor()(value[value_src_id]); + key_cache[target_key_id] = CastFunctor()(key[key_src_id]); + value_cache[target_value_id] = CastFunctor()(value[value_src_id]); } } @@ -81,7 +99,7 @@ template void apply_context_kv_cache_memcpy( torch::Tensor& key, // [num_tokens, head_num, head_dim] torch::Tensor& value, // [num_tokens, head_num, head_dim] - torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& cu_seqlens, // [batch_size + 1] @@ -91,7 +109,8 @@ void apply_context_kv_cache_memcpy( int num_tokens = key.size(0); int head_num = key.size(1); int head_dim = key.size(2); - int block_size = key_cache.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); int batch_size = block_tables.size(0); int64_t key_stride = key.stride(0); @@ -127,7 +146,8 @@ void apply_context_kv_cache_memcpy( batch_size, \ block_table_stride, \ key_stride, \ - value_stride \ + value_stride, \ + x \ ); \ } while(0) @@ -164,7 +184,7 @@ void apply_context_kv_cache_memcpy( void context_kv_cache_memcpy( torch::Tensor& key, // [num_tokens, head_num, head_dim] torch::Tensor& value, // [num_tokens, head_num, head_dim] - torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& cu_seqlens, // [batch_size + 1] diff --git a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu index f29379f5c..03682187e 100644 --- a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu @@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel( const int block_size, const int64_t key_stride, const int64_t value_stride, - const int block_table_stride + const int block_table_stride, + const int x ) { const int seq_id = blockIdx.x; @@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel( for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { const int head_id = i / head_dim; const int head_offset = i % head_dim; + const int x_id = head_offset / x; + const int x_offset = head_offset % x; const int64_t key_src_id = seq_id * key_stride + i; const int64_t value_src_id = seq_id * value_stride + i; - const int64_t target_id = block_id * hidden_size * block_size + const int64_t target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + const int64_t target_value_id = block_id * hidden_size * block_size + head_id * block_size * head_dim + block_offset * head_dim + head_offset; - copy_vector(key_cache + target_id, key + key_src_id); - copy_vector(value_cache + target_id, value + value_src_id); + copy_vector(key_cache + target_key_id, key + key_src_id); + copy_vector(value_cache + target_value_id, value + value_src_id); } if (!Aligned) { for (; i < hidden_size; ++i ) { const int head_id = i / head_dim; const int head_offset = i % head_dim; + const int x_id = head_offset / x; + const int x_offset = head_offset % x; const int64_t key_src_id = seq_id * key_stride + i; const int64_t value_src_id = seq_id * value_stride + i; - const int64_t target_id = block_id * hidden_size * block_size + const int64_t target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + const int64_t target_value_id = block_id * hidden_size * block_size + head_id * block_size * head_dim + block_offset * head_dim + head_offset; - key_cache[target_id] = key[key_src_id]; - value_cache[target_id] = value[value_src_id]; + key_cache[target_key_id] = key[key_src_id]; + value_cache[target_value_id] = value[value_src_id]; } } @@ -69,7 +84,7 @@ template void apply_decode_kv_cache_memcpy( at::Tensor& key, // [num_tokens, head_num, head_dim] at::Tensor& value, // [num_tokens, head_num, head_dim] - at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] at::Tensor& block_tables) // [batch_size, max_seq_len] @@ -77,7 +92,8 @@ void apply_decode_kv_cache_memcpy( int num_tokens = key.size(0); int head_num = key.size(1); int head_dim = key.size(2); - int block_size = key_cache.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); int64_t key_stride = key.stride(0); int64_t value_stride = value.stride(0); @@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy( block_size, \ key_stride, \ value_stride, \ - block_table_stride \ + block_table_stride, \ + x \ ); \ } while(0) @@ -146,7 +163,7 @@ void apply_decode_kv_cache_memcpy( void decode_kv_cache_memcpy( at::Tensor& key, // [num_tokens, head_num, head_dim] at::Tensor& value, // [num_tokens, head_num, head_dim] - at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] at::Tensor& block_tables) // [batch_size, max_seq_len] diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu index ac5e40725..110907435 100644 --- a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu +++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu @@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel( const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size] const int* __restrict__ context_lens, // [num_tokens] const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq] + const float* __restrict__ alibi_slopes, // [num_heads] const int max_seq_len, const int num_kv_heads, const float scale, @@ -105,6 +106,7 @@ __global__ void flash_decoding_attention_kernel( using FloatVecT = typename FloatVecTypeTrait::Type; const int context_len = context_lens[seq_idx]; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const int thread_group_offset = lane % NUM_THREADS_PER_X; const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; @@ -164,6 +166,7 @@ __global__ void flash_decoding_attention_kernel( if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) { const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; const bool mask = token_idx >= context_len; logits[token_idx] = mask ? 0.f : qk; qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -261,6 +264,7 @@ __global__ void flash_decoding_attention_kernel( reinterpret_cast(value_cache.data_ptr()), \ context_lens.data_ptr(), \ block_tables.data_ptr(), \ + alibi_slopes_ptr, \ max_context_len, \ num_kv_heads, \ scale, \ @@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher( torch::Tensor& context_lens, // [num_tokens] torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] int max_context_len, - float scale) { + float scale, + const c10::optional& alibi_slopes) { int num_tokens = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -304,6 +309,10 @@ void flash_decoding_attention_v1_launcher( // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + dim3 grid(num_heads, num_tokens, 1); dim3 block(NUM_THREADS); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); @@ -336,7 +345,8 @@ void flash_decoding_attention_v1_launcher( context_lens, \ block_tables, \ max_context_len, \ - scale); + scale, \ + alibi_slopes); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -367,6 +377,7 @@ void flash_decoding_attention( int max_context_len, torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions] + const c10::optional& alibi_slopes, float scale) { diff --git a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu index 52f3588a7..7a2629171 100644 --- a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -91,7 +91,7 @@ __device__ void apply_k_rotary_emb_compute( const int* __restrict__ block_tables, const int64_t key_stride, const int64_t value_stride, const int token_id, const int block_table_stride, const int head_num, const int head_dim, - const int kv_head_num, const int block_size, const int half_head_dim, + const int kv_head_num, const int block_size, const int x, const int half_head_dim, const int shard_block_size) { const int seq_len = sequence_lengths[token_id] - 1; const int block_offset = seq_len % block_size; @@ -102,36 +102,40 @@ __device__ void apply_k_rotary_emb_compute( return; } - scalar_t x[VecSize]; - scalar_t y[VecSize]; + scalar_t x0[VecSize]; + scalar_t x1[VecSize]; scalar_t out_x[VecSize]; scalar_t out_y[VecSize]; for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim; i += blockDim.x * VecSize) { - const int head_offset = i % half_head_dim; + const int half_head_offset = i % half_head_dim; + const int x_id = half_head_offset / x; + const int x_offset = half_head_offset % x; const int shard_offset = - (head_offset / shard_block_size) * shard_block_size + - (head_offset % shard_block_size) / VecSize; + (half_head_offset / shard_block_size) * shard_block_size + + (half_head_offset % shard_block_size) / VecSize; const int64_t addr_offset = - token_id * key_stride + (i / half_head_dim) * head_dim + head_offset; - const int64_t target_id = block_id * kv_head_num * head_dim * block_size + - (i / half_head_dim) * block_size * head_dim + - block_offset * head_dim + head_offset; + token_id * key_stride + (i / half_head_dim) * head_dim + half_head_offset; + const int64_t target_id = block_id * kv_head_num * head_dim * block_size + + (i / half_head_dim) * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; - copy_vector(x, key + addr_offset); - copy_vector(y, key + addr_offset + half_head_dim); + copy_vector(x0, key + addr_offset); + copy_vector(x1, key + addr_offset + half_head_dim); #pragma unroll for (int j = 0; j < VecSize; j++) { - out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - - static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); - out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + - static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); + out_x[j] = static_cast(static_cast(x0[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(x1[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(x1[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x0[j]) * sin_ptr[j * 32 + shard_offset]); } copy_vector(key_cache + target_id, out_x); - copy_vector(key_cache + target_id + half_head_dim, + copy_vector(key_cache + target_id + half_head_dim * block_size, out_y); } @@ -162,7 +166,8 @@ __global__ void rotary_embedding_and_cache_copy_kernel( const int head_num, const int head_dim, const int kv_head_num, - const int block_size + const int block_size, + const int x ) { const int token_id = blockIdx.x; @@ -182,7 +187,7 @@ __global__ void rotary_embedding_and_cache_copy_kernel( apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); //compute key and copy kv - apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size); } template @@ -220,6 +225,31 @@ __global__ void rotary_embedding_kernel( apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); } +#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \ + rotary_embedding_and_cache_copy_kernel<<>>( \ + query.data_ptr(), \ + key.data_ptr(), \ + value.data_ptr(), \ + cos.data_ptr(), \ + sin.data_ptr(), \ + key_cache.data_ptr(), \ + value_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + block_tables.data_ptr(), \ + query_stride, \ + key_stride, \ + value_stride, \ + shard_element_num / 2, \ + cos_stride, \ + sin_stride, \ + block_table_stride, \ + head_num, \ + head_dim, \ + kv_head_num, \ + block_size, \ + x); \ + + template void apply_rotary_embedding_and_cache_copy( at::Tensor& query, // [num_tokens, head_num, head_dim] @@ -227,7 +257,7 @@ void apply_rotary_embedding_and_cache_copy( at::Tensor& value, // [num_tokens, kv_head_num, head_dim] at::Tensor& cos, // [num_tokens, head_dim] at::Tensor& sin, // [num_tokens, head_dim] - at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] at::Tensor& block_tables) // [batch_size, max_seq_len] @@ -236,7 +266,8 @@ void apply_rotary_embedding_and_cache_copy( int head_num = query.size(1); int head_dim = query.size(2); int kv_head_num = key.size(1); - int block_size = key_cache.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); int64_t query_stride = query.stride(0); int64_t key_stride = key.stride(0); @@ -261,80 +292,18 @@ void apply_rotary_embedding_and_cache_copy( dim3 grid(num_tokens); dim3 block(std::min(thread_nums, 512)); - int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size; + const int shared_memory_size = shard_element_num * sizeof(m_scalar_t); switch (vec_size) { case 1: - rotary_embedding_and_cache_copy_kernel<<>>( - query.data_ptr(), - key.data_ptr(), - value.data_ptr(), - cos.data_ptr(), - sin.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - sequence_lengths.data_ptr(), - block_tables.data_ptr(), - query_stride, - key_stride, - value_stride, - shard_element_num / 2, - cos_stride, - sin_stride, - block_table_stride, - head_num, - head_dim, - kv_head_num, - block_size - ); + ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(1); break; case 2: - rotary_embedding_and_cache_copy_kernel<<>>( - query.data_ptr(), - key.data_ptr(), - value.data_ptr(), - cos.data_ptr(), - sin.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - sequence_lengths.data_ptr(), - block_tables.data_ptr(), - query_stride, - key_stride, - value_stride, - shard_element_num / 2, - cos_stride, - sin_stride, - block_table_stride, - head_num, - head_dim, - kv_head_num, - block_size - ); + ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(2); break; case 4: - rotary_embedding_and_cache_copy_kernel<<>>( - query.data_ptr(), - key.data_ptr(), - value.data_ptr(), - cos.data_ptr(), - sin.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - sequence_lengths.data_ptr(), - block_tables.data_ptr(), - query_stride, - key_stride, - value_stride, - shard_element_num / 2, - cos_stride, - sin_stride, - block_table_stride, - head_num, - head_dim, - kv_head_num, - block_size - ); + ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(4); break; default: AT_ERROR("Unsupported vectorized size ", vec_size); @@ -441,7 +410,7 @@ void rotary_embedding_and_cache_copy( at::Tensor& value, // [num_tokens, kv_head_num, head_dim] at::Tensor& cos, // [num_tokens, head_dim] at::Tensor& sin, // [num_tokens, head_dim] - at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] at::Tensor& block_tables, // [batch_size, max_seq_len] diff --git a/extensions/pybind/inference/inference.cpp b/extensions/pybind/inference/inference.cpp index 0604d4c71..e0fac00bd 100644 --- a/extensions/pybind/inference/inference.cpp +++ b/extensions/pybind/inference/inference.cpp @@ -1,18 +1,19 @@ #include void decode_kv_cache_memcpy( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size] torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& block_tables); // [batch_size, max_seq_len] void context_kv_cache_memcpy( - at::Tensor& key, // [num_tokens, head_num, head_dim] - at::Tensor& value, // [num_tokens, head_num, head_dim] - at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] at::Tensor& cu_seqlens, // [batch_size + 1] @@ -27,12 +28,13 @@ void rotary_embedding( bool high_precision); void rotary_embedding_and_cache_copy( - torch::Tensor& query, // [num_tokens, head_num, head_dim] - torch::Tensor& key, // [num_tokens, kv_head_num, head_dim] - torch::Tensor& value, // [num_tokens, num_heads, head_dim] - torch::Tensor& cos, // [num_tokens, head_dim] - torch::Tensor& sin, // [num_tokens, head_dim] - torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor& query, // [num_tokens, head_num, head_dim] + torch::Tensor& key, // [num_tokens, kv_head_num, head_dim] + torch::Tensor& value, // [num_tokens, num_heads, head_dim] + torch::Tensor& cos, // [num_tokens, head_dim] + torch::Tensor& sin, // [num_tokens, head_dim] + torch::Tensor& + key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_dim] torch::Tensor& sequence_lengths, // [batch_size] @@ -71,7 +73,7 @@ void flash_decoding_attention( torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions] - float scale); + const c10::optional& alibi_slopes, float scale); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, diff --git a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py index 1a4d363a2..b3bd503bb 100644 --- a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py @@ -4,8 +4,10 @@ import numpy as np import pytest import torch +from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device +from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask inference_ops = InferenceOpsLoader().load() @@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol): @pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("use_alibi_slopes", [True, False]) def test_flash_decoding_attention( - BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype + BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes ): torch.manual_seed(123) torch.cuda.empty_cache() @@ -73,6 +76,11 @@ def test_flash_decoding_attention( MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ device = get_current_device() + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device) + else: + alibi_slopes = None + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device ) @@ -91,6 +99,15 @@ def test_flash_decoding_attention( v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + if use_alibi_slopes: + alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device) + torch_padding_mask = torch_padding_mask + alibi_mask + + if len(torch_padding_mask.size()) == 4: + torch_padding_mask = torch_padding_mask[:, :, -1:, :] + else: + torch_padding_mask = torch_padding_mask[:, -1:, :] + mid_output = torch.empty( size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device ) @@ -146,8 +163,14 @@ def test_flash_decoding_attention( max_seq_len_across_batch, mid_output, mid_output_lse, + alibi_slopes, sm_scale, ) + + # The alibi may introduce relatively large errors + if use_alibi_slopes: + rtol = 1e0 + numpy_allclose(out_ref, output, rtol=rtol, atol=atol) @@ -168,8 +191,9 @@ except ImportError: @pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("use_alibi_slopes", [True, False]) def test_vllm_flash_decoding_attention( - BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype + BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes ): torch.manual_seed(123) torch.cuda.empty_cache() @@ -199,6 +223,18 @@ def test_vllm_flash_decoding_attention( v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device) + alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device) + torch_padding_mask = torch_padding_mask + alibi_mask + + if len(torch_padding_mask.size()) == 4: + torch_padding_mask = torch_padding_mask[:, :, -1:, :] + else: + torch_padding_mask = torch_padding_mask[:, -1:, :] + else: + alibi_slopes = None + if dtype == torch.float16: rtol = 1e-3 atol = 1e-3 @@ -236,8 +272,6 @@ def test_vllm_flash_decoding_attention( HEAD_SIZE, ) - alibi_slopes = None - vllm_ops.paged_attention_v1( output, q.squeeze(2), @@ -253,6 +287,11 @@ def test_vllm_flash_decoding_attention( "auto", kv_scale, ) + + # The alibi may introduce relatively large errors + if use_alibi_slopes: + rtol = 1e0 + numpy_allclose(out_ref, output, rtol=rtol, atol=atol) @@ -277,5 +316,5 @@ if __name__ == "__main__": dtype, ) in test_combinations: test_flash_decoding_attention( - batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype + batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True ) diff --git a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py index 3fa17037f..e9c99ddc7 100644 --- a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py +++ b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py @@ -4,12 +4,40 @@ import torch.nn.functional as F from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device -from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2 -from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data +from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token inference_ops = InferenceOpsLoader().load() -HEAD_DIM = 4 +HEAD_DIM = 72 + + +def prepare_data( + bsz, + num_kv_heads, + block_size, + max_num_blocks_per_seq, + context_lengths, + device="cuda", + dtype=torch.float16, +): + num_tokens = torch.sum(context_lengths).item() + + max_seq_len_in_batch = context_lengths.max() + cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + + kv_size = (num_tokens, num_kv_heads, HEAD_DIM) + key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3( + key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + + block_tables = block_tables.to(device=device) + k_cache = torch.zeros_like(k_cache_ref) + v_cache = torch.zeros_like(v_cache_ref) + + return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref def run_decode_copy_kv_to_caches( @@ -24,32 +52,41 @@ def run_decode_copy_kv_to_caches( torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() + n = 1 + max_seq_len = block_size * max_num_blocks_per_seq dtype = torch.float32 device = get_current_device() - new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data( - bsz, - num_kv_heads, - HEAD_DIM, - block_size, - max_num_blocks_per_seq, - same_context_len, - max_seq_len, - device=device, - dtype=dtype, + assert max_seq_len > n, "max_seq_len must be greater than n" + + past_kv_seq_lengths = ( + torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device) + if same_context_len + else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device) ) - new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k - new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v - inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables) + key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data( + bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype + ) - past_kv_seq_len = kv_seq_lengths - 1 + new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device) + new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device) + + # mock allocating blocks for the new k/v and update block tables + for _ in range(n): + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + past_kv_seq_lengths += 1 + + inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables) + + past_kv_seq_len = past_kv_seq_lengths - 1 target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] offsets_in_block = past_kv_seq_len % block_size - k_target = k_cache[target_block_ids, :, offsets_in_block, :] + k_target = k_cache[target_block_ids, :, :, offsets_in_block, :] k_source = new_k.squeeze() v_target = v_cache[target_block_ids, :, offsets_in_block, :] + k_target = k_target.reshape(v_target.shape) v_source = new_v.squeeze() assert k_target.shape == k_source.shape @@ -77,22 +114,17 @@ def run_context_copy_kv_to_cache( else: context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) - num_tokens = torch.sum(context_lengths).item() - - max_seq_len_in_batch = context_lengths.max() - cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) - - kv_size = (num_tokens, num_kv_heads, HEAD_DIM) - key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) - value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) - - k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( - key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device - ) - - block_tables = block_tables.to(device=device) - k_cache = torch.zeros_like(k_cache_ref) - v_cache = torch.zeros_like(v_cache_ref) + ( + key, + value, + k_cache, + v_cache, + cu_seqlens, + block_tables, + max_seq_len_in_batch, + k_cache_ref, + v_cache_ref, + ) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype) inference_ops.context_kv_cache_memcpy( key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py index 6f5d0ac84..501bf65d8 100644 --- a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py @@ -7,7 +7,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader inference_ops = InferenceOpsLoader().load() -from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2 +from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3 from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb @@ -49,12 +49,14 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype): cos_shape = (TOTAL_TOKENS, D // 2) cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D) - k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") + x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x) + v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda") v = torch.randn_like(k) - v_cache = torch.zeros_like(k_cache) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda") past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") - block_tables = mock_alloc_block_table_and_kvcache_v2( + block_tables = mock_alloc_block_table_and_kvcache_v3( k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size ) new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda") @@ -97,9 +99,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype): past_kv_seq_len = kv_seq_lengths - 1 target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] offsets_in_block = past_kv_seq_len % block_size - k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze() + k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze() k_source = new_k_copy.squeeze() v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze() + k_target = k_target.reshape(v_target.shape) v_source = new_v.squeeze() numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)