mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Kernel] Optimize paged attention: Refactor key cache layout (#5643)
* Optimize flash decoding attention: refactor the kernel for a different key cache layout (from [num_blocks, num_kv_heads, block_size, head_size] to [num_blocks, num_kv_heads, head_size/x, block_size, x])

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 90cd5227a3
commit a8fd3b0342
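The key-cache layout change described in the commit message splits the head dimension into 16-byte chunks and moves the block dimension next to the innermost x elements. A standalone PyTorch sketch of that reshape (sizes are made up for illustration, not taken from this patch):

import torch

num_blocks, num_kv_heads, block_size, head_size = 4, 2, 16, 128
dtype = torch.float16
# x = number of key elements that fit in one 16-byte (float4) vector: 8 for fp16.
x = 16 // torch.tensor([], dtype=dtype).element_size()

# Old layout: [num_blocks, num_kv_heads, block_size, head_size]
k_cache_old = torch.randn(num_blocks, num_kv_heads, block_size, head_size, dtype=dtype)

# New layout: [num_blocks, num_kv_heads, head_size/x, block_size, x]
k_cache_new = (
    k_cache_old.reshape(num_blocks, num_kv_heads, block_size, head_size // x, x)
    .permute(0, 1, 3, 2, 4)
    .contiguous()
)
assert k_cache_new.shape == (num_blocks, num_kv_heads, head_size // x, block_size, x)

With this layout, the x key elements a thread needs for one head-dimension chunk of one token are contiguous in memory, so each read can be a single 16-byte vector load.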
@@ -593,7 +593,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 high_precision,
             )
             # inference_ops.flash_decoding_attention(
-            #     attn_output,
+            #     output_tensor,
             #     query_states,
             #     k_cache,
             #     v_cache,
@@ -605,6 +605,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
             #     fd_inter_tensor.mid_output_lse,
             #     sm_scale,
             # )
+            # attn_output = output_tensor
         else:
             if is_verifier:
                 rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

@@ -5,6 +5,7 @@ from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
 from tests.test_infer.test_ops.triton.kernel_utils import (
     generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
     generate_caches_and_block_tables_vllm,
 )
 
@@ -95,7 +96,11 @@ def benchmark_flash_decoding_attention(
         BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
     )
 
-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+    triton_k_cache, triton_v_cache, _ = generate_caches_and_block_tables_v2(
+        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
+    )
+
+    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
         k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
     )
 
@@ -135,8 +140,8 @@ def benchmark_flash_decoding_attention(
     elif provider == "triton_flash_decoding_attention":
         fn = lambda: flash_decoding_attention(
             q.squeeze(2),
-            k_cache,
-            v_cache,
+            triton_k_cache,
+            triton_v_cache,
             kv_seq_lengths,
             block_tables,
             BLOCK_SIZE,

@@ -41,7 +41,8 @@ namespace attention {
 #define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
 
 // Q*K^T operation.
-template <int NUM_THREADS_PER_TOKEN, typename VecT, int N>
+template <int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X, typename VecT,
+          int N>
 inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
   using A_vec = typename common::FloatVecTypeTrait<VecT>::Type;
   // Compute the parallel products for Q*K^T (treat vector lanes separately).
@@ -58,21 +59,27 @@ inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
   // Finalize the reduction across lanes.
   float qk = sum_vect(qk_vec);
 #pragma unroll
-  for (int mask = (NUM_THREADS_PER_TOKEN >> 1); mask > 0; mask >>= 1) {
+  for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_ROUNDS;
+       mask >>= 1) {
+    qk += SHFL_XOR_SYNC(qk, mask);
+  }
+
+#pragma unroll
+  for (int mask = (NUM_THREADS_PER_X >> 1); mask > 0; mask >>= 1) {
     qk += SHFL_XOR_SYNC(qk, mask);
   }
   return qk;
 }
 
-template <typename T, int NUM_THREADS_PER_TOKEN>
+template <typename T, int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X>
 struct Qk_dot {
   template <typename VecT, int N>
   static inline __device__ float dot(const VecT (&q)[N], const VecT (&k)[N]) {
-    return qk_dot_<NUM_THREADS_PER_TOKEN>(q, k);
+    return qk_dot_<NUM_THREADS_PER_ROUNDS, NUM_THREADS_PER_X>(q, k);
   }
 };
 
-template <int NUM_WARPS, int NUM_THREADS_PER_TOKEN>
+template <int NUM_WARPS, int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X>
 inline __device__ float block_max(float* red_smem, float max) {
   int warp = threadIdx.x >> 5;
   int lane = threadIdx.x & 0x1f;
@@ -81,7 +88,8 @@ inline __device__ float block_max(float* red_smem, float max) {
   // for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the
   // max value among every NUM_THREADS_PER_TOKEN threads.
 #pragma unroll
-  for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_TOKEN; mask >>= 1) {
+  for (int mask = (NUM_THREADS_PER_ROUNDS >> 1); mask >= NUM_THREADS_PER_X;
+       mask >>= 1) {
     max = fmaxf(max, SHFL_XOR_SYNC(max, mask));
   }
 
@@ -155,10 +163,12 @@ inline __device__ void block_sum(float* red_smem, VecT& acc) {
     if (lane < NUM_THREADS_PER_GROUP) {
       if constexpr (N == VEC_SIZE_8) {
         VecT* vdst = &((reinterpret_cast<VecT*>(dst))[lane]);
-        (reinterpret_cast<float4*>(vdst))[0] =
-            (reinterpret_cast<float4*>(acc_ptr))[0];
-        (reinterpret_cast<float4*>(vdst))[1] =
-            (reinterpret_cast<float4*>(acc_ptr))[1];
+        const int idx0 = (lane >> 2) & 0x1;
+        const int idx1 = idx0 ^ 0x1;
+        (reinterpret_cast<float4*>(vdst))[idx0] =
+            (reinterpret_cast<float4*>(acc_ptr))[idx0];
+        (reinterpret_cast<float4*>(vdst))[idx1] =
+            (reinterpret_cast<float4*>(acc_ptr))[idx1];
       } else {
         (reinterpret_cast<VecT*>(dst))[lane] = acc;
       }
@@ -173,10 +183,12 @@ inline __device__ void block_sum(float* red_smem, VecT& acc) {
     float* src_ptr = reinterpret_cast<float*>(&src_reg);
     if constexpr (N == VEC_SIZE_8) {
       VecT* vsrc = &((reinterpret_cast<VecT*>(src))[lane]);
-      (reinterpret_cast<float4*>(src_ptr))[0] =
-          (reinterpret_cast<float4*>(vsrc))[0];
-      (reinterpret_cast<float4*>(src_ptr))[1] =
-          (reinterpret_cast<float4*>(vsrc))[1];
+      const int idx0 = (lane >> 2) & 0x1;
+      const int idx1 = idx0 ^ 0x1;
+      (reinterpret_cast<float4*>(src_ptr))[idx0] =
+          (reinterpret_cast<float4*>(vsrc))[idx0];
+      (reinterpret_cast<float4*>(src_ptr))[idx1] =
+          (reinterpret_cast<float4*>(vsrc))[idx1];
     } else {
       src_reg = (reinterpret_cast<VecT*>(src))[lane];
     }

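The qk_dot_ reduction above now finishes in two stages of XOR shuffles: first across the lanes that hold different rows, stopping at NUM_THREADS_PER_ROUNDS, then across the NUM_THREADS_PER_X lanes that share one x-wide chunk. A pure-Python emulation of the lane exchanges (illustration only; the two constants below are assumed example values, not from a real template instantiation):

WARP_SIZE = 32
NUM_THREADS_PER_ROUNDS = 8  # assumed: NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X
NUM_THREADS_PER_X = 2       # assumed: x / VEC_SIZE

def shfl_xor_reduce(vals, start_mask, stop_mask):
    # Each lane i adds the value held by lane (i ^ mask); mask halves every step.
    mask = start_mask
    while mask >= stop_mask:
        vals = [vals[i] + vals[i ^ mask] for i in range(len(vals))]
        mask >>= 1
    return vals

lane_vals = [float(i) for i in range(WARP_SIZE)]
# Stage 1: combine partial dot products across rows, stopping at NUM_THREADS_PER_ROUNDS.
stage1 = shfl_xor_reduce(lane_vals, WARP_SIZE >> 1, NUM_THREADS_PER_ROUNDS)
# Stage 2: combine the NUM_THREADS_PER_X lanes that cover one x-wide chunk.
stage2 = shfl_xor_reduce(stage1, NUM_THREADS_PER_X >> 1, 1)
# For these constants, lane 0 ends up with the sum over lanes 0, 1, 8, 9, 16, 17, 24, 25.
assert stage2[0] == sum(lane_vals[i] for i in (0, 1, 8, 9, 16, 17, 24, 25))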
@@ -1,6 +1,6 @@
 /*This code adapted from vllm:
  * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu
- * with different kvcache layout. */
+ */
 
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
@@ -50,7 +50,7 @@ template<typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int
 __global__ void flash_decoding_attention_kernel(
   scalar_t* __restrict__ out,             // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_tokens, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
   const int* __restrict__ context_lens,   // [num_tokens]
   const int* __restrict__ block_tables,   // [num_tokens, max_num_blocks_per_seq]
@@ -70,15 +70,19 @@ __global__ void flash_decoding_attention_kernel(
   const int num_queries_per_kv = num_heads / num_kv_heads;
   const int kv_head_idx = head_idx / num_queries_per_kv;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  constexpr int Q_SHARED_SIZE = (HEAD_SIZE * sizeof(scalar_t)) / sizeof(float4);
+  constexpr int x = sizeof(float4) / sizeof(scalar_t);
+  constexpr int Q_SHARED_SIZE = HEAD_SIZE / x;
   // here thread_group does not determine the number of threads responsible for a key
   // but only the VEC_SIZE of each thread
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(scalar_t));
+  constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), x);
   constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE;
   constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);
   constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN;
   constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN;
+  constexpr int NUM_THREADS_PER_X = x / VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
+  constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
 
   using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
   using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
@@ -86,15 +90,17 @@ __global__ void flash_decoding_attention_kernel(
   using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
 
   const int context_len = context_lens[seq_idx];
-  const int thread_group_offset = thread_idx % NUM_THREADS_PER_TOKEN;
+  const int thread_group_offset = lane % NUM_THREADS_PER_X;
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
 
   __shared__ float4 q_shared[Q_SHARED_SIZE];
   __shared__ float red_shared_mem[2 * NUM_WARPS];
   extern __shared__ char shared_mem[];
-  float* logits = reinterpret_cast<float*>(shared_mem);
-  float* out_shared_mem = reinterpret_cast<float*>(shared_mem);
+  int* block_table_shared = reinterpret_cast<int*>(shared_mem);
+  float* logits = reinterpret_cast<float*>(shared_mem + shared_memory_offset);
+  float* out_shared_mem = reinterpret_cast<float*>(shared_mem + shared_memory_offset);
   float qk_max = -FLT_MAX;
 
   const float4* q_ptr = reinterpret_cast<const float4*>(q + seq_idx * q_stride + head_idx * HEAD_SIZE);
@@ -102,32 +108,47 @@ __global__ void flash_decoding_attention_kernel(
   for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) {
     q_shared[idx] = q_ptr[idx];
   }
+
+#pragma unroll
+  for (int idx = thread_idx; idx < max_num_blocks_per_seq; idx += blockDim.x) {
+    block_table_shared[idx] = block_table[idx];
+  }
+
   __syncthreads();
 
   scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
   // each warp access a whole block
+
+  K_vec q_vecs[NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
+    const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
+    const int offset1 = idx % NUM_THREADS_PER_X;
+    q_vecs[i] = *reinterpret_cast<K_vec*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
+  }
+
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
+
+    K_vec k_vecs[NUM_VECS_PER_THREAD];
+
 #pragma unroll
-    for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) {
-      const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN;
+    for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
       const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                      + kv_head_idx * kv_head_stride
-                                     + idx * VEC_SIZE;
-
-      K_vec k_vecs[NUM_ROUNDS_PER_TOKEN];
-      K_vec q_vecs[NUM_ROUNDS_PER_TOKEN];
-
-      // we must calculate at least one row of hidden vectors
+                                     + i * x;
 #pragma unroll
-      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-        k_vecs[i] = (reinterpret_cast<const K_vec*>(k_ptr))[i * WARP_SIZE];
-        q_vecs[i] = (reinterpret_cast<K_vec*>(q_shared_ptr))[(idx + i * WARP_SIZE) % NUM_VECS_PER_TOKEN];
+      for (int idx = lane, j = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, j += 1) {
+        const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
+        const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
+        const int offset2 = idx % NUM_THREADS_PER_X;
+        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE);
       }
 
-      float qk = scale * Qk_dot<scalar_t, NUM_THREADS_PER_TOKEN>::dot(q_vecs, k_vecs);
+      float qk = scale * Qk_dot<scalar_t, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);
 
-      if (thread_group_offset == 0) {
+      if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
+        const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
         const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -136,7 +157,7 @@ __global__ void flash_decoding_attention_kernel(
   }
 
   // there exists a __syncthreads within this function
-  qk_max = block_max<NUM_WARPS, NUM_THREADS_PER_TOKEN>(red_shared_mem, qk_max);
+  qk_max = block_max<NUM_WARPS, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>(red_shared_mem, qk_max);
 
   // Get the sum of the exp values.
   float exp_sum = 0.f;
@@ -162,7 +183,7 @@ __global__ void flash_decoding_attention_kernel(
   V_vec zero_value;
   zero(zero_value);
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
     scalar_t logit;
 
 #pragma unroll
@@ -241,7 +262,7 @@ template<
 void flash_decoding_attention_v1_launcher(
   torch::Tensor& out,          // [num_tokens, num_heads, head_size]
   torch::Tensor& query,        // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
+  torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
   torch::Tensor& context_lens, // [num_tokens]
   torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
@@ -266,7 +287,7 @@ void flash_decoding_attention_v1_launcher(
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float);
   // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
 
   dim3 grid(num_heads, num_tokens, 1);
   dim3 block(NUM_THREADS);
@@ -323,7 +344,7 @@ void flash_decoding_attention_v1_launcher(
 void flash_decoding_attention(
   torch::Tensor& out,          // [num_tokens, num_heads, head_size]
   torch::Tensor& query,        // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
+  torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
   torch::Tensor& context_lens, // [num_tokens]
   torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]

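The refactored gather loop in the kernel above addresses the key cache through offset0/offset1/offset2 so that each lane reads one contiguous VEC_SIZE slice of an x-wide chunk. A host-side emulation of that index math (a sketch using an assumed fp16 configuration, not an actual kernel launch) can be checked against the logical [token, dim] view of a block:

import torch

# Constants mirror the kernel's constexpr math for an assumed configuration.
WARP_SIZE, HEAD_SIZE, BLOCK_SIZE, ELEM_BYTES = 32, 128, 16, 2
x = 16 // ELEM_BYTES                                # elements per float4
THREAD_GROUP_SIZE = max(WARP_SIZE // BLOCK_SIZE, 1)
VEC_SIZE = min(HEAD_SIZE // THREAD_GROUP_SIZE, x)   # power-of-two rounding omitted: 64 is already a power of two
NUM_VECS_PER_TOKEN = HEAD_SIZE // VEC_SIZE
NUM_THREADS_PER_X = x // VEC_SIZE
NUM_ROWS_PER_ROUNDS = min(WARP_SIZE // NUM_THREADS_PER_X, BLOCK_SIZE)

keys = torch.arange(BLOCK_SIZE * HEAD_SIZE).reshape(BLOCK_SIZE, HEAD_SIZE)  # logical [token, dim]
# One block of one kv head, flattened in the new layout [head_size/x, block_size, x].
k_block_v3 = keys.reshape(BLOCK_SIZE, HEAD_SIZE // x, x).permute(1, 0, 2).reshape(-1)

i = 0  # first group of NUM_ROWS_PER_ROUNDS tokens handled by the warp
for lane in range(WARP_SIZE):
    for idx in range(lane, NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN, WARP_SIZE):
        offset0 = idx // NUM_THREADS_PER_X // NUM_ROWS_PER_ROUNDS   # which x-wide chunk of the head dim
        offset1 = (idx // NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS  # which token row inside the group
        offset2 = idx % NUM_THREADS_PER_X                           # which VEC_SIZE slice of that chunk
        start = offset0 * BLOCK_SIZE * x + (i + offset1) * x + offset2 * VEC_SIZE
        gathered = k_block_v3[start : start + VEC_SIZE]
        token, dim0 = i + offset1, offset0 * x + offset2 * VEC_SIZE
        assert torch.equal(gathered, keys[token, dim0 : dim0 + VEC_SIZE])

Every gathered slice is contiguous in the v3 layout, which is what lets the kernel replace strided per-token reads with aligned vector loads.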
@@ -287,7 +287,7 @@ void rms_layernorm(
       RMSNORM_LAUNCHER(8, block);
       break;
     default:
-      AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
+      AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
     }
   }
 }
@@ -334,7 +334,7 @@ void fused_add_rms_layernorm(
       FUSED_ADD_RMSNORM_LAUNCHER(8, block);
       break;
     default:
-      AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
+      AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
     }
   }
 }

@@ -62,7 +62,7 @@ void flash_decoding_attention(
     torch::Tensor& out,    // [num_tokens, num_heads, head_size]
     torch::Tensor& query,  // [num_tokens, num_heads, head_size]
     torch::Tensor&
-        key_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
+        key_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
     torch::Tensor&
         value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
     torch::Tensor& context_lens,  // [num_tokens]

@@ -12,7 +12,7 @@ inference_ops = InferenceOpsLoader().load()
 from tests.test_infer.test_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
-    generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
     generate_caches_and_block_tables_vllm,
     torch_attn_ref,
 )
@@ -77,7 +77,7 @@ def test_flash_decoding_attention(
         BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
     )
 
-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
         k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
     )
 

@@ -150,6 +150,50 @@ def mock_alloc_block_table_and_kvcache_v2(
     return block_tables
 
 
+def mock_alloc_block_table_and_kvcache_v3(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    context_lengths: torch.Tensor,
+    num_seqs: int,
+    max_num_blocks_per_seq: int,
+    block_size: int,
+) -> torch.Tensor:
+    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
+    block_id = 0
+    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
+    num_tokens_processed = 0
+
+    _, num_kv_heads, head_dim = k.shape
+
+    x = 16 // torch.tensor([], dtype=k.dtype).element_size()
+
+    for i, seq_len in enumerate(context_lengths.tolist()):
+        right_bound = (seq_len + block_size - 1) // block_size  # open bound
+        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
+        # Manually fill kv caches by copying from k and v
+        for i in range(right_bound):
+            if i == right_bound - 1:
+                allocated_locs = seq_len % block_size or block_size
+            else:
+                allocated_locs = block_size
+            # [block_size, num_kv_heads, head_dim/x, x] -> [num_kv_heads, head_dim/x, block_size, x]
+            k_block = (
+                k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
+                .reshape(allocated_locs, num_kv_heads, head_dim // x, x)
+                .permute(1, 2, 0, 3)
+            )
+            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
+            k_cache[block_id, :, :, :allocated_locs, :] = k_block
+            v_cache[block_id, :, :allocated_locs, :] = v_block
+
+            num_tokens_processed += allocated_locs
+            block_id += 1
+
+    return block_tables
+
+
 def mock_alloc_block_table_and_kvcache_vllm(
     k: torch.Tensor,
     v: torch.Tensor,
@@ -251,6 +295,26 @@ def generate_caches_and_block_tables_v2(
     return k_cache, v_cache, block_tables
 
 
+def generate_caches_and_block_tables_v3(
+    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
+) -> Tuple[torch.Tensor, ...]:
+    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
+    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
+    _, num_kv_heads, head_dim = k_unpad.shape
+
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+
+    k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
+    v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
+    k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
+    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
+    # Mock allocation on block tables as well as blocked kv caches
+    block_tables = mock_alloc_block_table_and_kvcache_v3(
+        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
+    )
+    return k_cache, v_cache, block_tables
+
+
 def generate_caches_and_block_tables_vllm(
     k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
 ) -> Tuple[torch.Tensor, ...]:
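A usage sketch for the new test helper added above (assuming the tests package is importable and a CUDA device is available; all sizes are illustrative):

import torch
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3

bsz, max_blocks, block_size, num_kv_heads, head_dim = 2, 4, 16, 2, 128
kv_lengths = torch.tensor([20, 37], dtype=torch.int32)
num_tokens = int(kv_lengths.sum())
k_unpad = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_unpad = torch.randn_like(k_unpad)

k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
    k_unpad, v_unpad, kv_lengths, bsz, max_blocks, block_size, dtype=torch.float16, device="cuda"
)

x = 16 // torch.tensor([], dtype=torch.float16).element_size()  # 8 for fp16
# Key cache uses the new layout; value cache keeps the old one.
assert k_cache.shape == (bsz * max_blocks, num_kv_heads, head_dim // x, block_size, x)
assert v_cache.shape == (bsz * max_blocks, num_kv_heads, block_size, head_dim)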