@@ -1,6 +1,6 @@
 /*This code adapted from vllm:
  * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu
- * with different kvcache layout. */
+ */
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
@@ -50,7 +50,7 @@ template<typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int
 __global__ void flash_decoding_attention_kernel(
     scalar_t* __restrict__ out,             // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ q,         // [num_tokens, num_heads, head_size]
-    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
+    const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
     const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
     const int* __restrict__ context_lens,   // [num_tokens]
     const int* __restrict__ block_tables,   // [num_tokens, max_num_blocks_per_seq]
@@ -70,15 +70,19 @@ __global__ void flash_decoding_attention_kernel(
   const int num_queries_per_kv = num_heads / num_kv_heads;
   const int kv_head_idx = head_idx / num_queries_per_kv;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  constexpr int Q_SHARED_SIZE = (HEAD_SIZE * sizeof(scalar_t)) / sizeof(float4);
+  constexpr int x = sizeof(float4) / sizeof(scalar_t);
+  constexpr int Q_SHARED_SIZE = HEAD_SIZE / x;
   // here thread_group does not determine the number of threads responsible for a key
   // but only the VEC_SIZE of each thread
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(scalar_t));
+  constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), x);
   constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE;
   constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);
   constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN;
   constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN;
+  constexpr int NUM_THREADS_PER_X = x / VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
+  constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
   using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
   using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
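The constants added above can be cross-checked with a small standalone program. This is only a sketch under assumed values, scalar_t = half (2 bytes), HEAD_SIZE = 128, BLOCK_SIZE = 16, WARP_SIZE = 32; `cmin`, `cmax` and `round_down_pow2` merely stand in for the kernel's MIN, MAX and ROUND_DOWN_HIGHEST_POWER_OF_TWO macros and are not code from the patch.

```cpp
// Standalone sanity check of the new compile-time constants, assuming
// scalar_t = half (2 bytes), HEAD_SIZE = 128, BLOCK_SIZE = 16, WARP_SIZE = 32.
#include <cstdio>

constexpr int cmin(int a, int b) { return a < b ? a : b; }
constexpr int cmax(int a, int b) { return a > b ? a : b; }
constexpr int round_down_pow2(int v) { return v < 2 ? 1 : 2 * round_down_pow2(v / 2); }

int main() {
  constexpr int WARP_SIZE = 32, HEAD_SIZE = 128, BLOCK_SIZE = 16;
  constexpr int x = 16 / 2;  // sizeof(float4) / sizeof(half)
  constexpr int THREAD_GROUP_SIZE = cmax(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int VEC_SIZE = cmin(round_down_pow2(HEAD_SIZE / THREAD_GROUP_SIZE), x);
  constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE;
  constexpr int NUM_THREADS_PER_X = x / VEC_SIZE;
  constexpr int NUM_ROWS_PER_ROUNDS = cmin(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
  constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
  static_assert(x == 8 && VEC_SIZE == 8 && NUM_VECS_PER_TOKEN == 16, "per-token split");
  static_assert(NUM_THREADS_PER_X == 1 && NUM_ROWS_PER_ROUNDS == 16, "token rows per round");
  static_assert(NUM_VECS_PER_THREAD == 8, "K vectors gathered by each lane");
  // one warp covers a whole 16-token block: 16 tokens * 16 vecs / 32 lanes = 8 vecs per lane
  std::printf("x=%d VEC_SIZE=%d NUM_ROWS_PER_ROUNDS=%d NUM_VECS_PER_THREAD=%d\n",
              x, VEC_SIZE, NUM_ROWS_PER_ROUNDS, NUM_VECS_PER_THREAD);
  return 0;
}
```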
@@ -86,15 +90,17 @@ __global__ void flash_decoding_attention_kernel(
   using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
   const int context_len = context_lens[seq_idx];
-  const int thread_group_offset = thread_idx % NUM_THREADS_PER_TOKEN;
+  const int thread_group_offset = lane % NUM_THREADS_PER_X;
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
   __shared__ float4 q_shared[Q_SHARED_SIZE];
   __shared__ float red_shared_mem[2 * NUM_WARPS];
   extern __shared__ char shared_mem[];
-  float* logits = reinterpret_cast<float*>(shared_mem);
-  float* out_shared_mem = reinterpret_cast<float*>(shared_mem);
+  int* block_table_shared = reinterpret_cast<int*>(shared_mem);
+  float* logits = reinterpret_cast<float*>(shared_mem + shared_memory_offset);
+  float* out_shared_mem = reinterpret_cast<float*>(shared_mem + shared_memory_offset);
   float qk_max = -FLT_MAX;
   const float4* q_ptr = reinterpret_cast<const float4*>(q + seq_idx * q_stride + head_idx * HEAD_SIZE);
@@ -102,32 +108,47 @@ __global__ void flash_decoding_attention_kernel(
   for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) {
     q_shared[idx] = q_ptr[idx];
   }
+#pragma unroll
+  for (int idx = thread_idx; idx < max_num_blocks_per_seq; idx += blockDim.x) {
+    block_table_shared[idx] = block_table[idx];
+  }
   __syncthreads();
   scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
   // each warp access a whole block
+  K_vec q_vecs[NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
+    const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
+    const int offset1 = idx % NUM_THREADS_PER_X;
+    q_vecs[i] = *reinterpret_cast<K_vec*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
+  }
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
+    K_vec k_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
-    for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) {
-      const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN;
+    for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
       const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
           + kv_head_idx * kv_head_stride
-          + idx * VEC_SIZE;
-      K_vec k_vecs[NUM_ROUNDS_PER_TOKEN];
-      K_vec q_vecs[NUM_ROUNDS_PER_TOKEN];
+          + i * x;
+      // we must calculate at least one row of hidden vectors
 #pragma unroll
-      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-        k_vecs[i] = (reinterpret_cast<const K_vec*>(k_ptr))[i * WARP_SIZE];
-        q_vecs[i] = (reinterpret_cast<K_vec*>(q_shared_ptr))[(idx + i * WARP_SIZE) % NUM_VECS_PER_TOKEN];
+      for (int idx = lane, j = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, j += 1) {
+        const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
+        const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
+        const int offset2 = idx % NUM_THREADS_PER_X;
+        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE);
       }
-      float qk = scale * Qk_dot<scalar_t, NUM_THREADS_PER_TOKEN>::dot(q_vecs, k_vecs);
-      if (thread_group_offset == 0) {
+      float qk = scale * Qk_dot<scalar_t, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);
+      if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
+        const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
        const bool mask = token_idx >= context_len;
        logits[token_idx] = mask ? 0.f : qk;
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
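The offset0/offset1/offset2 arithmetic above walks the new key-cache layout [num_blocks, num_kv_heads, head_size/x, block_size, x], with x = sizeof(float4) / sizeof(scalar_t). The following is a minimal host-side sketch of the flat element index inside one (block, kv_head) tile, compared with the old [block_size, head_size] layout; `flat_key_index_new`/`flat_key_index_old` and the example sizes are illustrative assumptions, not code from the patch.

```cpp
// Illustrative sketch: flat element index of key[token_in_block][dim] inside
// one (block, kv_head) tile under the new x-chunked layout vs. the old one.
#include <cstdio>

constexpr int flat_key_index_new(int dim, int token_in_block, int block_size, int x) {
  return (dim / x) * block_size * x   // which x-wide chunk of the head dimension
         + token_in_block * x         // which token row inside that chunk
         + (dim % x);                 // position inside the chunk
}

constexpr int flat_key_index_old(int dim, int token_in_block, int head_size) {
  return token_in_block * head_size + dim;  // token-major, head dim contiguous
}

int main() {
  constexpr int HEAD_SIZE = 128, BLOCK_SIZE = 16;
  constexpr int X = 8;  // sizeof(float4) / sizeof(scalar_t) for 2-byte scalars
  // e.g. token 3, head dim 21 of one tile:
  static_assert(flat_key_index_old(21, 3, HEAD_SIZE) == 405, "old layout");
  static_assert(flat_key_index_new(21, 3, BLOCK_SIZE, X) == 2 * BLOCK_SIZE * X + 3 * X + 5,
                "new layout");
  std::printf("old=%d new=%d\n", flat_key_index_old(21, 3, HEAD_SIZE),
              flat_key_index_new(21, 3, BLOCK_SIZE, X));
  return 0;
}
```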
@@ -136,7 +157,7 @@ __global__ void flash_decoding_attention_kernel(
   }
   // there exists a __syncthreads within this function
-  qk_max = block_max<NUM_WARPS, NUM_THREADS_PER_TOKEN>(red_shared_mem, qk_max);
+  qk_max = block_max<NUM_WARPS, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>(red_shared_mem, qk_max);
   // Get the sum of the exp values.
   float exp_sum = 0.f;
@@ -162,7 +183,7 @@ __global__ void flash_decoding_attention_kernel(
   V_vec zero_value;
   zero(zero_value);
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
     scalar_t logit;
 #pragma unroll
@@ -241,7 +262,7 @@ template<
 void flash_decoding_attention_v1_launcher(
     torch::Tensor& out,            // [num_tokens, num_heads, head_size]
    torch::Tensor& query,          // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,      // [num_blocks, num_kv_heads, block_size, head_size]
+    torch::Tensor& key_cache,      // [num_blocks, num_kv_heads, head_size/x, block_size, x]
     torch::Tensor& value_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
     torch::Tensor& context_lens,   // [num_tokens]
     torch::Tensor& block_tables,   // [num_tokens, max_num_blocks_per_seq]
@@ -266,7 +287,7 @@ void flash_decoding_attention_v1_launcher(
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float);
   // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
   dim3 grid(num_heads, num_tokens, 1);
   dim3 block(NUM_THREADS);
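The extra term added to shared_mem_size reserves room for the block-table copy that the kernel now places at the start of dynamic shared memory, rounded up to sizeof(float4), presumably so the logits/output buffer behind it stays 16-byte aligned; it mirrors shared_memory_offset inside the kernel. A small sketch of that budget follows; `round_up`, the byte constants and the example sizes are assumptions for illustration, not values from the patch.

```cpp
// Illustrative sketch of the dynamic shared memory budget after the change:
// [ block table copy, rounded up to 16 B ][ max(logits, outputs) scratch ]
#include <algorithm>
#include <cstdio>

constexpr int kFloat4Bytes = 16;  // sizeof(float4) on the device
constexpr int kIntBytes = 4;      // sizeof(int)
constexpr int kFloatBytes = 4;    // sizeof(float)

// same arithmetic as DIVIDE_ROUND_UP(bytes, align) * align in the patch
constexpr int round_up(int bytes, int align) { return (bytes + align - 1) / align * align; }

int main() {
  // example sizes, not taken from the patch
  const int max_num_blocks_per_seq = 42;
  const int padded_max_context_len = 672;
  const int logits_size = padded_max_context_len * kFloatBytes;
  const int outputs_size = (4 / 2) * 16 * 8 * kFloatBytes;  // (NUM_WARPS/2)*NUM_THREADS_PER_TOKEN*VEC_SIZE

  // block-table copy occupies shared_mem[0, offset); logits/out reuse the rest
  const int offset = round_up(max_num_blocks_per_seq * kIntBytes, kFloat4Bytes);
  const int shared_mem_size = std::max(logits_size, outputs_size) + offset;
  std::printf("block-table offset = %d B, total dynamic smem = %d B\n", offset, shared_mem_size);
  return 0;
}
```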
@@ -323,7 +344,7 @@ void flash_decoding_attention_v1_launcher(
 void flash_decoding_attention(
     torch::Tensor& out,            // [num_tokens, num_heads, head_size]
     torch::Tensor& query,          // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,      // [num_blocks, num_kv_heads, block_size, head_size]
+    torch::Tensor& key_cache,      // [num_blocks, num_kv_heads, head_size/x, block_size, x]
     torch::Tensor& value_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
     torch::Tensor& context_lens,   // [num_tokens]
     torch::Tensor& block_tables,   // [num_tokens, max_num_blocks_per_seq]