diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index ff5a159cd..8249eafcf 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -593,7 +593,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): high_precision, ) # inference_ops.flash_decoding_attention( - # attn_output, + # output_tensor, # query_states, # k_cache, # v_cache, @@ -605,6 +605,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): # fd_inter_tensor.mid_output_lse, # sm_scale, # ) + # attn_output = output_tensor else: if is_verifier: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py index e33d9a9dc..1a18ffa2e 100644 --- a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py +++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py @@ -5,6 +5,7 @@ from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import ( generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, generate_caches_and_block_tables_vllm, ) @@ -95,7 +96,11 @@ def benchmark_flash_decoding_attention( BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device ) - k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + triton_k_cache, triton_v_cache, _ = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device ) @@ -135,8 +140,8 @@ def benchmark_flash_decoding_attention( elif provider == "triton_flash_decoding_attention": fn = lambda: flash_decoding_attention( q.squeeze(2), - k_cache, - v_cache, + triton_k_cache, + triton_v_cache, kv_seq_lengths, block_tables, BLOCK_SIZE, diff --git a/extensions/csrc/kernel/cuda/attention/attention_utils.h b/extensions/csrc/kernel/cuda/attention/attention_utils.h index fa555fdc8..732936809 100644 --- a/extensions/csrc/kernel/cuda/attention/attention_utils.h +++ b/extensions/csrc/kernel/cuda/attention/attention_utils.h @@ -41,7 +41,8 @@ namespace attention { #define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) // Q*K^T operation. -template +template inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) { using A_vec = typename common::FloatVecTypeTrait::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). @@ -58,21 +59,27 @@ inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) { // Finalize the reduction across lanes. float qk = sum_vect(qk_vec); #pragma unroll - for (int mask = (NUM_THREADS_PER_TOKEN >> 1); mask > 0; mask >>= 1) { + for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_ROUNDS; + mask >>= 1) { + qk += SHFL_XOR_SYNC(qk, mask); + } + +#pragma unroll + for (int mask = (NUM_THREADS_PER_X >> 1); mask > 0; mask >>= 1) { qk += SHFL_XOR_SYNC(qk, mask); } return qk; } -template +template struct Qk_dot { template static inline __device__ float dot(const VecT (&q)[N], const VecT (&k)[N]) { - return qk_dot_(q, k); + return qk_dot_(q, k); } }; -template +template inline __device__ float block_max(float* red_smem, float max) { int warp = threadIdx.x >> 5; int lane = threadIdx.x & 0x1f; @@ -81,7 +88,8 @@ inline __device__ float block_max(float* red_smem, float max) { // for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the // max value among every NUM_THREADS_PER_TOKEN threads. #pragma unroll - for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_TOKEN; mask >>= 1) { + for (int mask = (NUM_THREADS_PER_ROUNDS >> 1); mask >= NUM_THREADS_PER_X; + mask >>= 1) { max = fmaxf(max, SHFL_XOR_SYNC(max, mask)); } @@ -155,10 +163,12 @@ inline __device__ void block_sum(float* red_smem, VecT& acc) { if (lane < NUM_THREADS_PER_GROUP) { if constexpr (N == VEC_SIZE_8) { VecT* vdst = &((reinterpret_cast(dst))[lane]); - (reinterpret_cast(vdst))[0] = - (reinterpret_cast(acc_ptr))[0]; - (reinterpret_cast(vdst))[1] = - (reinterpret_cast(acc_ptr))[1]; + const int idx0 = (lane >> 2) & 0x1; + const int idx1 = idx0 ^ 0x1; + (reinterpret_cast(vdst))[idx0] = + (reinterpret_cast(acc_ptr))[idx0]; + (reinterpret_cast(vdst))[idx1] = + (reinterpret_cast(acc_ptr))[idx1]; } else { (reinterpret_cast(dst))[lane] = acc; } @@ -173,10 +183,12 @@ inline __device__ void block_sum(float* red_smem, VecT& acc) { float* src_ptr = reinterpret_cast(&src_reg); if constexpr (N == VEC_SIZE_8) { VecT* vsrc = &((reinterpret_cast(src))[lane]); - (reinterpret_cast(src_ptr))[0] = - (reinterpret_cast(vsrc))[0]; - (reinterpret_cast(src_ptr))[1] = - (reinterpret_cast(vsrc))[1]; + const int idx0 = (lane >> 2) & 0x1; + const int idx1 = idx0 ^ 0x1; + (reinterpret_cast(src_ptr))[idx0] = + (reinterpret_cast(vsrc))[idx0]; + (reinterpret_cast(src_ptr))[idx1] = + (reinterpret_cast(vsrc))[idx1]; } else { src_reg = (reinterpret_cast(src))[lane]; } diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu index 8930ba04c..a004a98c3 100644 --- a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu +++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu @@ -1,6 +1,6 @@ /*This code adapted from vllm: * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu - * with different kvcache layout. */ + */ #include #include @@ -50,7 +50,7 @@ template::Type; using V_vec = typename VecTypeTrait::Type; @@ -86,15 +90,17 @@ __global__ void flash_decoding_attention_kernel( using Float_vec = typename FloatVecTypeTrait::Type; const int context_len = context_lens[seq_idx]; - const int thread_group_offset = thread_idx % NUM_THREADS_PER_TOKEN; + const int thread_group_offset = lane % NUM_THREADS_PER_X; const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); __shared__ float4 q_shared[Q_SHARED_SIZE]; __shared__ float red_shared_mem[2 * NUM_WARPS]; extern __shared__ char shared_mem[]; - float* logits = reinterpret_cast(shared_mem); - float* out_shared_mem = reinterpret_cast(shared_mem); + int* block_table_shared = reinterpret_cast(shared_mem); + float* logits = reinterpret_cast(shared_mem + shared_memory_offset); + float* out_shared_mem = reinterpret_cast(shared_mem + shared_memory_offset); float qk_max = -FLT_MAX; const float4* q_ptr = reinterpret_cast(q + seq_idx * q_stride + head_idx * HEAD_SIZE); @@ -102,32 +108,47 @@ __global__ void flash_decoding_attention_kernel( for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) { q_shared[idx] = q_ptr[idx]; } + + #pragma unroll + for (int idx = thread_idx; idx < max_num_blocks_per_seq; idx += blockDim.x) { + block_table_shared[idx] = block_table[idx]; + } + __syncthreads(); scalar_t* q_shared_ptr = reinterpret_cast(q_shared); // each warp access a whole block + + K_vec q_vecs[NUM_VECS_PER_THREAD]; + #pragma unroll + for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) { + const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS; + const int offset1 = idx % NUM_THREADS_PER_X; + q_vecs[i] = *reinterpret_cast(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE); + } + for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) { - const int64_t physical_block_number = static_cast(block_table[block_idx]); + const int64_t physical_block_number = static_cast(block_table_shared[block_idx]); + + K_vec k_vecs[NUM_VECS_PER_THREAD]; + #pragma unroll - for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) { - const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN; + for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) { const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride - + idx * VEC_SIZE; - - K_vec k_vecs[NUM_ROUNDS_PER_TOKEN]; - K_vec q_vecs[NUM_ROUNDS_PER_TOKEN]; - - // we must calculate at least one row of hidden vectors + + i * x; #pragma unroll - for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { - k_vecs[i] = (reinterpret_cast(k_ptr))[i * WARP_SIZE]; - q_vecs[i] = (reinterpret_cast(q_shared_ptr))[(idx + i * WARP_SIZE) % NUM_VECS_PER_TOKEN]; + for (int idx = lane, j = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, j += 1) { + const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS; + const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS; + const int offset2 = idx % NUM_THREADS_PER_X; + k_vecs[j] = *reinterpret_cast(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE); } - float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - if (thread_group_offset == 0) { + if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) { + const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X; const bool mask = token_idx >= context_len; logits[token_idx] = mask ? 0.f : qk; qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -136,7 +157,7 @@ __global__ void flash_decoding_attention_kernel( } // there exists a __syncthreads within this function - qk_max = block_max(red_shared_mem, qk_max); + qk_max = block_max(red_shared_mem, qk_max); // Get the sum of the exp values. float exp_sum = 0.f; @@ -162,7 +183,7 @@ __global__ void flash_decoding_attention_kernel( V_vec zero_value; zero(zero_value); for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) { - const int64_t physical_block_number = static_cast(block_table[block_idx]); + const int64_t physical_block_number = static_cast(block_table_shared[block_idx]); scalar_t logit; #pragma unroll @@ -241,7 +262,7 @@ template< void flash_decoding_attention_v1_launcher( torch::Tensor& out, // [num_tokens, num_heads, head_size] torch::Tensor& query, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] torch::Tensor& context_lens, // [num_tokens] torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] @@ -266,7 +287,7 @@ void flash_decoding_attention_v1_launcher( int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float); // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); + int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); dim3 grid(num_heads, num_tokens, 1); dim3 block(NUM_THREADS); @@ -323,7 +344,7 @@ void flash_decoding_attention_v1_launcher( void flash_decoding_attention( torch::Tensor& out, // [num_tokens, num_heads, head_size] torch::Tensor& query, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] torch::Tensor& context_lens, // [num_tokens] torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] diff --git a/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu b/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu index 0cd330b5f..c9bd3d72d 100644 --- a/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu @@ -287,7 +287,7 @@ void rms_layernorm( RMSNORM_LAUNCHER(8, block); break; default: - AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); + AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8"); } } } @@ -334,7 +334,7 @@ void fused_add_rms_layernorm( FUSED_ADD_RMSNORM_LAUNCHER(8, block); break; default: - AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); + AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8"); } } } diff --git a/extensions/pybind/inference/inference.cpp b/extensions/pybind/inference/inference.cpp index 9997cc54c..0604d4c71 100644 --- a/extensions/pybind/inference/inference.cpp +++ b/extensions/pybind/inference/inference.cpp @@ -62,7 +62,7 @@ void flash_decoding_attention( torch::Tensor& out, // [num_tokens, num_heads, head_size] torch::Tensor& query, // [num_tokens, num_heads, head_size] torch::Tensor& - key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] torch::Tensor& context_lens, // [num_tokens] diff --git a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py index f641a9102..babd6595c 100644 --- a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py @@ -12,7 +12,7 @@ inference_ops = InferenceOpsLoader().load() from tests.test_infer.test_ops.triton.kernel_utils import ( convert_kv_unpad_to_padded, create_attention_mask, - generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, generate_caches_and_block_tables_vllm, torch_attn_ref, ) @@ -77,7 +77,7 @@ def test_flash_decoding_attention( BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device ) - k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device ) diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_ops/triton/kernel_utils.py index 507c185b5..6bb947d00 100644 --- a/tests/test_infer/test_ops/triton/kernel_utils.py +++ b/tests/test_infer/test_ops/triton/kernel_utils.py @@ -150,6 +150,50 @@ def mock_alloc_block_table_and_kvcache_v2( return block_tables +def mock_alloc_block_table_and_kvcache_v3( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + + _, num_kv_heads, head_dim = k.shape + + x = 16 // torch.tensor([], dtype=k.dtype).element_size() + + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + # [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x] + k_block = ( + k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :] + .reshape(allocated_locs, num_kv_heads, head_dim // x, x) + .permute(1, 2, 0, 3) + ) + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2) + k_cache[block_id, :, :, :allocated_locs, :] = k_block + v_cache[block_id, :, :allocated_locs, :] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + def mock_alloc_block_table_and_kvcache_vllm( k: torch.Tensor, v: torch.Tensor, @@ -251,6 +295,26 @@ def generate_caches_and_block_tables_v2( return k_cache, v_cache, block_tables +def generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + + x = 16 // torch.tensor([], dtype=dtype).element_size() + + k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache_v3( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + def generate_caches_and_block_tables_vllm( k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" ) -> Tuple[torch.Tensor, ...]: