/* This code is adapted from vLLM:
 * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu
 */
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cfloat>  // FLT_MAX

#include "common/micros.h"
#include "funcs/cast_functor.h"
#include "funcs/ternary_functor.h"
#include "funcs/binary_functor.h"
#include "common/vec_type_traits.h"
#include "attention/attention_utils.h"

#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// rounds x down to the highest power of two <= x: 2^n => 2^n, 2^n-d => 2^(n-1)
#define ROUND_DOWN_HIGHEST_POWER_OF_TWO(x) (nextHighestPowerOf2((x - (x + 1) / 2 + 1)))

// classic bit-twiddling trick: rounds v up to the next power of two
// 2^n => 2^n, 2^n-d => 2^n
constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}
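
// Illustrative compile-time sanity checks for the two helpers above (example
// values only): nextHighestPowerOf2 rounds up to the next power of two, while
// ROUND_DOWN_HIGHEST_POWER_OF_TWO keeps the highest power of two <= x.
static_assert(nextHighestPowerOf2(33) == 64, "33 rounds up to 64");
static_assert(ROUND_DOWN_HIGHEST_POWER_OF_TWO(80) == 64, "80 rounds down to 64");
static_assert(ROUND_DOWN_HIGHEST_POWER_OF_TWO(64) == 64, "powers of two are kept");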

template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ii++) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

using colossalAI::funcs::BinaryOpType;
using colossalAI::funcs::CastFunctor;
using colossalAI::funcs::TernaryOpFunctor;
using colossalAI::funcs::TernaryOpType;
using colossalAI::common::VecTypeTrait;
using colossalAI::common::FloatVecTypeTrait;
using namespace colossalAI::cuda::attention;

// We only support head sizes of { 64, 128, 256 };
// models like Phi-2, whose head size is 80, are not supported right now.
template<typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS>
__global__ void flash_decoding_attention_kernel(
    scalar_t* __restrict__ out,              // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ q,          // [num_tokens, num_heads, head_size]
    const cache_t* __restrict__ k_cache,     // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,     // [num_blocks, num_kv_heads, block_size, head_size]
    const int* __restrict__ context_lens,    // [num_tokens]
    const int* __restrict__ block_tables,    // [num_tokens, max_num_blocks_per_seq]
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int max_seq_len,
    const int num_kv_heads,
    const float scale,
    const int max_num_blocks_per_seq,
    const int q_stride,                      // num_heads * head_size
    const int kv_block_stride,
    const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int thread_idx = threadIdx.x;
  const int lane = thread_idx % WARP_SIZE;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  constexpr int x = sizeof(float4) / sizeof(scalar_t);
  constexpr int Q_SHARED_SIZE = HEAD_SIZE / x;
  // here THREAD_GROUP_SIZE does not determine the number of threads responsible for a key;
  // it only determines the VEC_SIZE of each thread
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), x);
  constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE;
  constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);
  constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN;
  constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN;
  constexpr int NUM_THREADS_PER_X = x / VEC_SIZE;
  constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
  constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
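
  // Worked example of the derived constants (illustrative values only):
  // for scalar_t = half, HEAD_SIZE = 128, BLOCK_SIZE = 16 we get
  //   x = 8, THREAD_GROUP_SIZE = 2, VEC_SIZE = 8, NUM_VECS_PER_TOKEN = 16,
  //   NUM_THREADS_PER_TOKEN = 16, NUM_ROUNDS_PER_TOKEN = 1,
  //   NUM_THREADS_PER_X = 1, NUM_ROWS_PER_ROUNDS = 16, NUM_VECS_PER_THREAD = 8,
  // i.e. each warp covers a whole 16-token block per round of the key loop.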

  using KVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
  using VVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
  using KQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
  using VQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
  using LVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
  using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;

  const int context_len = context_lens[seq_idx];
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
  const int thread_group_offset = lane % NUM_THREADS_PER_X;
  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);

  __shared__ float4 q_shared[Q_SHARED_SIZE];
  __shared__ float red_shared_mem[2 * NUM_WARPS];
  extern __shared__ char shared_mem[];
  int* block_table_shared = reinterpret_cast<int*>(shared_mem);
  float* logits = reinterpret_cast<float*>(shared_mem + shared_memory_offset);
  float* out_shared_mem = reinterpret_cast<float*>(shared_mem + shared_memory_offset);
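  // Dynamic shared memory layout (as set up above): the first
  // shared_memory_offset bytes hold this sequence's block table; the remaining
  // bytes are reused twice, first as `logits` (softmax scores) and later as
  // `out_shared_mem` (cross-warp output reduction), hence the __syncthreads()
  // between the two uses further below.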
  float qk_max = -FLT_MAX;

  const float4* q_ptr = reinterpret_cast<const float4*>(q + seq_idx * q_stride + head_idx * HEAD_SIZE);
#pragma unroll
  for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) {
    q_shared[idx] = q_ptr[idx];
  }

#pragma unroll
  for (int idx = thread_idx; idx < max_num_blocks_per_seq; idx += blockDim.x) {
    block_table_shared[idx] = block_table[idx];
  }

  __syncthreads();

  scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
  // each warp accesses a whole block

  KVecT q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
  for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
    const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
    const int offset1 = idx % NUM_THREADS_PER_X;
    q_vecs[i] = *reinterpret_cast<KVecT*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
  }
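  // Each lane now holds its query vectors laid out to mirror the key cache's
  // [head_size/x, block_size, x] tiling, so that a lane's q_vecs[i] lines up
  // with the k_vecs[i] it loads in the loop below and Qk_dot can consume the
  // two arrays element-for-element.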

  for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
    const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);

    KVecT k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
    for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
      const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride
                                     + i * x;
#pragma unroll
      for (int idx = lane, j = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, j += 1) {
        const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
        const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
        const int offset2 = idx % NUM_THREADS_PER_X;
        k_vecs[j] = CastFunctor<KQuantVecT, KVecT>()(*reinterpret_cast<const KQuantVecT*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE));
      }

      float qk = scale * Qk_dot<scalar_t, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);

      if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
        const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
        qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
        const bool mask = token_idx >= context_len;
        logits[token_idx] = mask ? 0.f : qk;
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // there exists a __syncthreads within this function
  qk_max = block_max<NUM_WARPS, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>(red_shared_mem, qk_max);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }

  exp_sum = block_sum<NUM_WARPS>(&red_shared_mem[NUM_WARPS], exp_sum);
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();
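  // logits[] now holds the numerically stable softmax weights
  // exp(qk - qk_max) / sum. Second pass below: each warp re-walks its blocks of
  // the value cache and accumulates logits[token] * V into per-thread float
  // accumulators.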

  FloatVecT accs[NUM_ROUNDS_PER_TOKEN];
#pragma unroll
  for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
    zero(accs[i]);
  }

  VVecT zero_value;
  zero(zero_value);
  for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
    const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
    scalar_t logit;

#pragma unroll
    for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) {
      const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN;
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride
                                     + idx * VEC_SIZE;

      VVecT v_vecs[NUM_ROUNDS_PER_TOKEN];

#pragma unroll
      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
        v_vecs[i] = CastFunctor<VQuantVecT, VVecT>()(*((reinterpret_cast<const VQuantVecT*>(v_ptr) + i * WARP_SIZE)));
      }

      if (token_idx >= context_len) {
#pragma unroll
        for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
          v_vecs[i] = zero_value;
        }
      }

      logit = CastFunctor<float, scalar_t>()(logits[token_idx]);
#pragma unroll
      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
        accs[i] = TernaryOpFunctor<scalar_t, VVecT, FloatVecT, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
      }
    }
  }

  // must insert a sync since both logits and out_shared_mem occupy the same buffer space
  __syncthreads();

#pragma unroll
  for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
    block_sum<FloatVecT, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
  }

  scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE;
  LVecT out_reg;
#pragma unroll
  for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
    if (thread_idx < NUM_THREADS_PER_TOKEN) {
      out_reg = CastFunctor<FloatVecT, LVecT>()(accs[i]);
      (reinterpret_cast<LVecT*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
    }
  }
}

#define LAUNCH_FLASH_DECODING_ATTENTION_V1(HEAD_SIZE)                                            \
  cudaFuncSetAttribute(                                                                          \
      ((void*)flash_decoding_attention_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),  \
      cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                             \
  flash_decoding_attention_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                \
      <<<grid, block, shared_mem_size, stream>>>(                                                \
          reinterpret_cast<T*>(out.data_ptr()),                                                  \
          reinterpret_cast<T*>(query.data_ptr()),                                                \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                      \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                    \
          context_lens.data_ptr<int>(),                                                          \
          block_tables.data_ptr<int>(),                                                          \
          alibi_slopes_ptr,                                                                       \
          max_context_len,                                                                        \
          num_kv_heads,                                                                           \
          scale,                                                                                  \
          max_num_blocks_per_seq,                                                                 \
          q_stride,                                                                               \
          kv_block_stride,                                                                        \
          kv_head_stride);

template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128>
void flash_decoding_attention_v1_launcher(
    torch::Tensor& out,            // [num_tokens, num_heads, head_size]
    torch::Tensor& query,          // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,      // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& context_lens,   // [num_tokens]
    torch::Tensor& block_tables,   // [num_tokens, max_num_blocks_per_seq]
    int max_context_len,
    float scale,
    const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_tokens = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int num_kv_heads = key_cache.size(1);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  const int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((head_size / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(T));
  const int NUM_VECS_PER_TOKEN = head_size / VEC_SIZE;
  const int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);

  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float);
  // Keep this in sync with the kernel's shared memory usage above!
  int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
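  // Illustrative sizing (assumed example values, not measured): with T = half,
  // BLOCK_SIZE = 16, head_size = 128, max_context_len = 1024 and
  // max_num_blocks_per_seq = 64, logits_size = 1024 * 4 = 4096 bytes,
  // outputs_size = (4 / 2) * 16 * 8 * 4 = 1024 bytes, and the block table adds
  // 64 * 4 = 256 bytes (already float4-aligned), so shared_mem_size = 4352 bytes.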

  const float* alibi_slopes_ptr = alibi_slopes ?
      reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
      : nullptr;

  dim3 grid(num_heads, num_tokens, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model.
    case 64:
      LAUNCH_FLASH_DECODING_ATTENTION_V1(64);
      break;
    case 128:
      LAUNCH_FLASH_DECODING_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_FLASH_DECODING_ATTENTION_V1(256);
      break;
    default:
      AT_ERROR("head size must be 64, 128, 256");
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE)                   \
  flash_decoding_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE>(    \
      out,                                                         \
      query,                                                       \
      key_cache,                                                   \
      value_cache,                                                 \
      context_lens,                                                \
      block_tables,                                                \
      max_context_len,                                             \
      scale,                                                       \
      alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T)                    \
  switch (block_size) {                                            \
    case 8:                                                        \
      CALL_V1_LAUNCHER(T, CACHE_T, 8);                             \
      break;                                                       \
    case 16:                                                       \
      CALL_V1_LAUNCHER(T, CACHE_T, 16);                            \
      break;                                                       \
    case 32:                                                       \
      CALL_V1_LAUNCHER(T, CACHE_T, 32);                            \
      break;                                                       \
    default:                                                       \
      AT_ERROR("block size must be 8, 16, 32");                    \
      break;                                                       \
  }

void flash_decoding_attention(
    torch::Tensor& out,             // [num_tokens, num_heads, head_size]
    torch::Tensor& query,           // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,       // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,     // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& context_lens,    // [num_tokens]
    torch::Tensor& block_tables,    // [num_tokens, max_num_blocks_per_seq]
    int block_size,
    int max_context_len,
    torch::Tensor& tmp_out,         // [num_tokens, num_heads, max_num_partitions, head_size]
    torch::Tensor& tmp_out_lse,     // [num_tokens, num_heads, max_num_partitions]
    const c10::optional<torch::Tensor>& alibi_slopes,
    float scale) {
  if (key_cache.scalar_type() == at::ScalarType::Byte) {
    switch (query.scalar_type()) {
      case at::ScalarType::Float:
        CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t);
        break;
      case at::ScalarType::Half:
        CALL_V1_LAUNCHER_BLOCK_SIZE(half, uint8_t);
        break;
      case at::ScalarType::BFloat16:
        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t);
        break;
      default:
        AT_ERROR("query dtype must be float, half or bfloat16");
        break;
    }
  } else {
    switch (query.scalar_type()) {
      case at::ScalarType::Float:
        CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
        break;
      case at::ScalarType::Half:
        CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
        break;
      case at::ScalarType::BFloat16:
        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
        break;
      default:
        AT_ERROR("query dtype must be float, half or bfloat16");
        break;
    }
  }
}
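
// Illustrative host-side call (a sketch only; the real bindings and Python-side
// wrappers live elsewhere in ColossalAI, and the tensor names below are
// hypothetical). Expected shapes and dtypes follow the signature above:
//   out:         [num_tokens, num_heads, head_size], same dtype as query
//   query:       [num_tokens, num_heads, head_size] (float / half / bfloat16)
//   key_cache:   [num_blocks, num_kv_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_kv_heads, block_size, head_size]
//   context_lens, block_tables: int32 tensors
//
//   flash_decoding_attention(out, query, key_cache, value_cache, context_lens,
//                            block_tables, block_size, max_context_len,
//                            tmp_out, tmp_out_lse, alibi_slopes, scale);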

#undef LAUNCH_FLASH_DECODING_ATTENTION_V1
#undef CALL_V1_LAUNCHER
#undef CALL_V1_LAUNCHER_BLOCK_SIZE