mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Kernel] Add Paged Decoding kernel, sequence split within the same thread block (#5531)
* feat: flash decoding for paged attention * refactor FlashDecodingAttention * [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 56b222eff8 · commit be396ad6cc · branch pull/5611/head
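Background for the diff below: flash decoding splits the cached KV sequence of each request into partitions, computes attention over each partition independently, and then merges the partial results using their log-sum-exp statistics (this is what the mid_output / mid_output_lse buffers that appear throughout this PR hold). The following standalone Python sketch illustrates only that merge rule; the function and variable names are illustrative and are not part of the ColossalAI API.

import torch

def attention_partition(q, k, v):
    # Attention over one KV partition: return the normalized partial output
    # together with the log-sum-exp of the raw scores.
    scores = (k @ q) / q.shape[0] ** 0.5            # [partition_len]
    m = scores.max()
    p = torch.exp(scores - m)
    return (p @ v) / p.sum(), m + torch.log(p.sum())

def merge_partitions(outs, lses):
    # Re-weight each partial output by a softmax over the per-partition LSEs;
    # the result equals attention computed over the full KV sequence.
    w = torch.softmax(torch.stack(lses), dim=0)
    return sum(wi * oi for wi, oi in zip(w, outs))

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(128, 64, dtype=torch.float64)
v = torch.randn(128, 64, dtype=torch.float64)

full, _ = attention_partition(q, k, v)               # single partition = reference
o1, l1 = attention_partition(q, k[:64], v[:64])      # KV split into two partitions
o2, l2 = attention_partition(q, k[64:], v[64:])
assert torch.allclose(full, merge_partitions([o1, o2], [l1, l2]))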
@@ -437,6 +437,19 @@ class NopadLlamaAttention(LlamaAttention):
                block_tables,
                high_precision,
            )
            # inference_ops.flash_decoding_attention(
            #     attn_output,
            #     query_states,
            #     k_cache,
            #     v_cache,
            #     sequence_lengths,
            #     block_tables,
            #     block_size,
            #     kv_seq_len,
            #     fd_inter_tensor.mid_output,
            #     fd_inter_tensor.mid_output_lse,
            #     sm_scale,
            # )
        else:
            if is_verifier:
                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
@@ -4,8 +4,8 @@ from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
    convert_kv_unpad_to_padded,
    create_attention_mask,
    generate_caches_and_block_tables_v2,
    prepare_padding_mask,
    torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data

@@ -67,9 +67,18 @@ def bench_kernel(
    if provider == "torch":
        k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
        v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
        torch_padding_mask = create_attention_mask(kv_lengths, bsz, Q_LEN, max_seq_len_in_b, q.device)
        fn = lambda: torch_attn_ref(
            q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
            q,
            k_torch,
            v_torch,
            torch_padding_mask,
            bsz,
            Q_LEN,
            max_seq_len_in_b,
            num_attn_heads,
            num_kv_heads,
            HEAD_DIM,
        )
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
    if provider == "triton":
@@ -0,0 +1,173 @@
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
    generate_caches_and_block_tables_v2,
    generate_caches_and_block_tables_vllm,
)

try:
    import triton  # noqa
except ImportError:
    print("please install triton from https://github.com/openai/triton")

inference_ops = InferenceOpsLoader().load()

# Triton benchmark plot attributes
configs = [
    triton.testing.Benchmark(
        x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
        x_vals=[2**i for i in range(3, 8)],
        line_arg="provider",
        line_vals=[
            "vllm_paged_decoding_attention",
            "triton_flash_decoding_attention",
            "cuda_flash_decoding_attention",
        ],
        line_names=[
            "vllm_paged_decoding_attention",
            "triton_flash_decoding_attention",
            "cuda_flash_decoding_attention",
        ],
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
        ylabel="ms",
        plot_name=f"FlashDecodingAttention benchmarking results",
        args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2},
    )
]


def prepare_data(
    BATCH_SIZE: int,
    HEAD_SIZE: int,
    NUM_ATTN_HEADS: int,
    NUM_KV_HEADS: int,
    MAX_SEQ_LEN: int,
    dtype=torch.float16,
    device="cuda",
):
    # Use the provided maximum sequence length for each sequence when testing with the same context length,
    # otherwise generate random context lengths.
    # returns
    #   q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
    #   k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
    kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(kv_lengths).item()

    q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE)
    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
    kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)

    return q, k_unpad, v_unpad, kv_lengths
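The cache generators used further down pack this variable-length, unpadded KV into fixed-size blocks. As a brief sketch of that paged layout: each sequence owns one row of the block table mapping its logical block index to a physical block in the cache, so token t of a sequence lives at slot t % block_size inside physical block block_table[t // block_size]. The snippet below is a minimal, framework-free illustration of that bookkeeping; the helper names are hypothetical, and the real tables come from generate_caches_and_block_tables_v2.

def build_block_tables(kv_lengths, block_size):
    # Toy allocator: hand out physical blocks sequentially, one row per sequence.
    tables, next_free = [], 0
    for length in kv_lengths:
        num_blocks = (length + block_size - 1) // block_size
        tables.append(list(range(next_free, next_free + num_blocks)))
        next_free += num_blocks
    return tables

def locate_token(block_table, token_idx, block_size):
    # Return (physical_block, in-block offset) for one token of a sequence.
    return block_table[token_idx // block_size], token_idx % block_size

tables = build_block_tables([5, 12], block_size=4)   # -> [[0, 1], [2, 3, 4]]
assert locate_token(tables[1], 9, 4) == (4, 1)       # token 9 of seq 1 -> block 4, slot 1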


@triton.testing.perf_report(configs)
def benchmark_flash_decoding_attention(
    provider: str,
    BATCH_SIZE: int,
    BLOCK_SIZE: int,
    MAX_NUM_BLOCKS_PER_SEQ: int,
    HEAD_SIZE: int,
    KV_GROUP_NUM: int,
):
    try:
        from vllm._C import ops as vllm_ops
    except ImportError:
        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

    warmup = 10
    rep = 1000

    dtype = torch.float16

    NUM_ATTN_HEADS = 16

    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
    MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
    device = get_current_device()

    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
        BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
    )

    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    block_tables = block_tables.to(device=device)
    max_seq_len_across_batch = kv_seq_lengths.max().item()
    kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
    output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
    sm_scale = 1.0 / (HEAD_SIZE**0.5)

    mid_output = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
    )
    mid_output_lse = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
    )
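A quick worked check of these buffer shapes (an illustration with assumed numbers, not part of the benchmark): at the x-axis point MAX_NUM_BLOCKS_PER_SEQ = 32 with BLOCK_SIZE = 32, a batch whose longest context happens to be 1000 tokens needs at most ceil(1000 / 32) = 32 KV partitions, so mid_output and mid_output_lse hold one un-normalized partial result and one log-sum-exp per (sequence, head, partition).

BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ = 32, 32          # one point on the benchmark's x-axis
max_seq_len_across_batch = 1000                      # hypothetical longest context in the batch
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
assert kv_max_split_num == 32                        # ceil(1000 / 32)
# mid_output:     (BATCH_SIZE, NUM_ATTN_HEADS, 32, HEAD_SIZE)  in float32
# mid_output_lse: (BATCH_SIZE, NUM_ATTN_HEADS, 32)             in float32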

    if provider == "vllm_paged_decoding_attention":
        alibi_slopes = None
        fn = lambda: vllm_ops.paged_attention_v1(
            output,
            q.squeeze(2),
            vllm_k_cache,
            vllm_v_cache,
            NUM_KV_HEADS,
            sm_scale,
            block_tables,
            kv_seq_lengths,
            BLOCK_SIZE,
            max_seq_len_across_batch,
            alibi_slopes,
            "auto",
        )
    elif provider == "triton_flash_decoding_attention":
        fn = lambda: flash_decoding_attention(
            q.squeeze(2),
            k_cache,
            v_cache,
            kv_seq_lengths,
            block_tables,
            BLOCK_SIZE,
            max_seq_len_across_batch,
            output,
            mid_output,
            mid_output_lse,
            sm_scale=sm_scale,
            kv_group_num=KV_GROUP_NUM,
        )  # [bsz, 1, num_heads, head_dim]
    elif provider == "cuda_flash_decoding_attention":
        fn = lambda: inference_ops.flash_decoding_attention(
            output,
            q.squeeze(2),
            k_cache,
            v_cache,
            kv_seq_lengths,
            block_tables,
            BLOCK_SIZE,
            max_seq_len_across_batch,
            mid_output,
            mid_output_lse,
            sm_scale,
        )
    else:
        raise ValueError("Undefined provider.")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

    return ms


if __name__ == "__main__":
    benchmark_flash_decoding_attention.run(save_path=".", print_data=True)
@@ -0,0 +1,206 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2024, The Colossal-AI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <float.h>

#include "../funcs/binary_functor.h"
#include "../funcs/cast_functor.h"
#include "../funcs/ternary_functor.h"
#include "../funcs/unary_functor.h"
#include "../utils/vec_type_traits.h"

namespace colossalAI {
namespace cuda {
namespace attention {

using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;
using colossalAI::cuda::funcs::TernaryOpFunctor;
using colossalAI::cuda::funcs::TernaryOpType;
using colossalAI::cuda::funcs::UnaryOpFunctor;
using colossalAI::cuda::funcs::UnaryOpType;
using colossalAI::cuda::utils::FloatVecTypeTrait;

#define WARP_SIZE 32
#define VEC_SIZE_8 8

#define SHFL_XOR_SYNC(var, lane_mask) \
  __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)

// Q*K^T operation.
template <int NUM_THREADS_PER_TOKEN, typename VecT, int N>
inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
  using A_vec = typename FloatVecTypeTrait<VecT>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  BinaryOpFunctor<VecT, VecT, A_vec, BinaryOpType::kMul> mul_vect;
  UnaryOpFunctor<A_vec, float, UnaryOpType::kSum> sum_vect;
  TernaryOpFunctor<VecT, VecT, A_vec, TernaryOpType::kFma> fma;

  A_vec qk_vec = mul_vect(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ii++) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum_vect(qk_vec);
#pragma unroll
  for (int mask = (NUM_THREADS_PER_TOKEN >> 1); mask > 0; mask >>= 1) {
    qk += SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}
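The XOR-shuffle loop at the end of qk_dot_ is a butterfly reduction: after log2(NUM_THREADS_PER_TOKEN) exchange steps, every participating lane holds the complete dot product, so no single lane has to be the designated owner of the result. The small Python simulation below models that access pattern only (lanes as a list, __shfl_xor_sync as indexing the partner lane); it is an illustration under those assumptions, not a full model of CUDA warp semantics.

def xor_butterfly_reduce(vals, group_size):
    # vals: one partial sum per lane; group_size: NUM_THREADS_PER_TOKEN (a power of two)
    lanes = list(vals)
    mask = group_size >> 1
    while mask > 0:
        # every lane adds the value held by the lane whose id differs in bit `mask`
        lanes = [lanes[i] + lanes[i ^ mask] for i in range(len(lanes))]
        mask >>= 1
    return lanes

partials = [1.0, 2.0, 3.0, 4.0]            # e.g. NUM_THREADS_PER_TOKEN == 4
assert xor_butterfly_reduce(partials, 4) == [10.0, 10.0, 10.0, 10.0]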

template <typename T, int NUM_THREADS_PER_TOKEN>
struct Qk_dot {
  template <typename VecT, int N>
  static inline __device__ float dot(const VecT (&q)[N], const VecT (&k)[N]) {
    return qk_dot_<NUM_THREADS_PER_TOKEN>(q, k);
  }
};

template <int NUM_WARPS, int NUM_THREADS_PER_TOKEN>
inline __device__ float block_max(float* red_smem, float max) {
  int warp = threadIdx.x >> 5;
  int lane = threadIdx.x & 0x1f;

  // Reduce across the threads of each warp to get the per-warp max; after this,
  // the first of every NUM_THREADS_PER_TOKEN threads already holds the max
  // over its group of NUM_THREADS_PER_TOKEN threads.
#pragma unroll
  for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_TOKEN; mask >>= 1) {
    max = fmaxf(max, SHFL_XOR_SYNC(max, mask));
  }

  if (lane == 0) red_smem[warp] = max;
  __syncthreads();

  // The warps compute the final maxs.
  max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;

  // Parallel reduction of all tokens from the same sequence inside the warp.
#pragma unroll
  for (int mask = (NUM_WARPS >> 1); mask > 0; mask >>= 1) {
    max = fmaxf(max, SHFL_XOR_SYNC(max, mask));
  }

  // Broadcast to other threads.
  return SHFL_SYNC(max, 0);
}

// We need a dedicated block_sum here instead of reusing block_reduce,
// since we need to manage the shared memory explicitly.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  int warp = threadIdx.x >> 5;
  int lane = threadIdx.x & 0x1f;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) {
    sum += SHFL_XOR_SYNC(sum, mask);
  }

  if (lane == 0) red_smem[warp] = sum;
  __syncthreads();

  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction of all tokens from the same sequence inside the warp.
#pragma unroll
  for (int mask = (NUM_WARPS >> 1); mask > 0; mask >>= 1) {
    sum += SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return SHFL_SYNC(sum, 0);
}
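Both block-wide reductions follow the same two-stage scheme: an intra-warp shuffle reduction, a hand-off of one value per warp through red_smem, and a second shuffle reduction over those per-warp values, followed by a broadcast. Under that reading, the following compact Python model of the scalar block_sum (shared memory as a plain list, warps as slices of 32 lanes) sketches the data flow only, not CUDA scheduling.

WARP_SIZE = 32

def block_sum_model(per_thread_vals):
    # Stage 1: each warp reduces its 32 values (the first shuffle loop in the kernel).
    warp_sums = [sum(per_thread_vals[w:w + WARP_SIZE])
                 for w in range(0, len(per_thread_vals), WARP_SIZE)]
    # Stage 2: lane 0 of each warp writes its sum to red_smem; the first NUM_WARPS
    # lanes then reduce those partial sums and the result is broadcast to everyone.
    red_smem = list(warp_sums)
    return sum(red_smem)

vals = [1.0] * 128                      # e.g. NUM_THREADS = 128 -> 4 warps
assert block_sum_model(vals) == 128.0   # every thread would receive this value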

// here VecT is a vector of float, whose size is N
template <typename VecT, int NUM_WARPS, int NUM_THREADS_PER_GROUP, int N>
inline __device__ void block_sum(float* red_smem, VecT& acc) {
  float* acc_ptr = reinterpret_cast<float*>(&acc);
  int warp = threadIdx.x >> 5;
  int lane = threadIdx.x & 0x1f;

#pragma unroll
  for (int i = 0; i < N; i++) {
#pragma unroll
    for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_GROUP;
         mask >>= 1) {
      acc_ptr[i] += SHFL_XOR_SYNC(acc_ptr[i], mask);
    }
  }

#pragma unroll
  for (int limit = NUM_WARPS; limit > 1; limit >>= 1) {
    int mid = limit >> 1;
    if (warp >= mid && warp < limit) {
      float* dst = red_smem + (warp - mid) * N * NUM_THREADS_PER_GROUP;
      if (lane < NUM_THREADS_PER_GROUP) {
        if constexpr (N == VEC_SIZE_8) {
          VecT* vdst = &((reinterpret_cast<VecT*>(dst))[lane]);
          (reinterpret_cast<float4*>(vdst))[0] =
              (reinterpret_cast<float4*>(acc_ptr))[0];
          (reinterpret_cast<float4*>(vdst))[1] =
              (reinterpret_cast<float4*>(acc_ptr))[1];
        } else {
          (reinterpret_cast<VecT*>(dst))[lane] = acc;
        }
      }
    }
    __syncthreads();

    if (warp < mid) {
      float* src = red_smem + warp * N * NUM_THREADS_PER_GROUP;
      VecT src_reg;
      if (lane < NUM_THREADS_PER_GROUP) {
        float* src_ptr = reinterpret_cast<float*>(&src_reg);
        if constexpr (N == VEC_SIZE_8) {
          VecT* vsrc = &((reinterpret_cast<VecT*>(src))[lane]);
          (reinterpret_cast<float4*>(src_ptr))[0] =
              (reinterpret_cast<float4*>(vsrc))[0];
          (reinterpret_cast<float4*>(src_ptr))[1] =
              (reinterpret_cast<float4*>(vsrc))[1];
        } else {
          src_reg = (reinterpret_cast<VecT*>(src))[lane];
        }
#pragma unroll
        for (int j = 0; j < N; j++) {
          acc_ptr[j] += src_ptr[j];
        }
      }
    }
    __syncthreads();
  }
}

#undef SHFL_SYNC
#undef SHFL_XOR_SYNC

}  // namespace attention
}  // namespace cuda
}  // namespace colossalAI
@@ -0,0 +1,353 @@
/* This code is adapted from vllm:
 * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu
 * with a different kvcache layout. */

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>

#include "../common/micros.h"
#include "funcs/cast_functor.h"
#include "funcs/ternary_functor.h"
#include "funcs/binary_functor.h"
#include "utils/vec_type_traits.h"
#include "attention/attention_utils.h"

#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// 2^n => 2^n, 2^n-d => 2^(n-1)
#define ROUND_DOWN_HIGHEST_POWER_OF_TWO(x) (nextHighestPowerOf2((x - (x + 1) / 2 + 1)))

// A bit of bit magic: rounds up to the next power of two.
// 2^n => 2^n, 2^n-d => 2^n
constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}
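A quick worked example of these two helpers, written as plain Python mirroring the C code (added purely for illustration): nextHighestPowerOf2 rounds up to a power of two, while ROUND_DOWN_HIGHEST_POWER_OF_TWO feeds it x - (x + 1) / 2 + 1 so that non-powers-of-two round down instead. This is what keeps the later VEC_SIZE computation a power of two even when HEAD_SIZE / THREAD_GROUP_SIZE is not.

def next_highest_power_of_2(v):
    v -= 1
    for shift in (1, 2, 4, 8, 16):
        v |= v >> shift
    return v + 1

def round_down_highest_power_of_2(x):
    return next_highest_power_of_2(x - (x + 1) // 2 + 1)

assert next_highest_power_of_2(64) == 64
assert next_highest_power_of_2(80) == 128
assert round_down_highest_power_of_2(64) == 64      # powers of two are unchanged
assert round_down_highest_power_of_2(80) == 64      # everything else rounds down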

using colossalAI::cuda::funcs::BinaryOpType;
using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::TernaryOpFunctor;
using colossalAI::cuda::funcs::TernaryOpType;
using colossalAI::cuda::funcs::zero;
using colossalAI::cuda::utils::VecTypeTrait;
using colossalAI::cuda::utils::FloatVecTypeTrait;
using namespace colossalAI::cuda::attention;


// We only support head sizes of { 64, 128, 256 };
// models like Phi-2, whose head size is 80, are not supported right now.
template<typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS>
__global__ void flash_decoding_attention_kernel(
    scalar_t* __restrict__ out,            // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ q,        // [num_tokens, num_heads, head_size]
    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, block_size, head_size]
    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, block_size, head_size]
    const int* __restrict__ context_lens,  // [num_tokens]
    const int* __restrict__ block_tables,  // [num_tokens, max_num_blocks_per_seq]
    const int max_seq_len,
    const int num_kv_heads,
    const float scale,
    const int max_num_blocks_per_seq,
    const int q_stride,                    // num_heads * head_size
    const int kv_block_stride,
    const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int thread_idx = threadIdx.x;
  const int lane = thread_idx % WARP_SIZE;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  constexpr int Q_SHARED_SIZE = (HEAD_SIZE * sizeof(scalar_t)) / sizeof(float4);
  // THREAD_GROUP_SIZE here does not determine the number of threads responsible
  // for a key; it only determines the VEC_SIZE of each thread.
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(scalar_t));
  constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE;
  constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);
  constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN;
  constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN;
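To make these compile-time constants concrete, here is the arithmetic for one plausible instantiation (fp16, HEAD_SIZE = 128, BLOCK_SIZE = 16, NUM_THREADS = 128), written out as a small Python check rather than anything the kernel executes. Each thread loads 16-byte vectors, one warp sweeps one KV block, and NUM_THREADS_PER_TOKEN threads cooperate on each token's dot product.

WARP_SIZE = 32
HEAD_SIZE, BLOCK_SIZE, NUM_THREADS = 128, 16, 128   # hypothetical fp16 instantiation
sizeof_scalar, sizeof_float4 = 2, 16                # half is 2 bytes

THREAD_GROUP_SIZE = max(WARP_SIZE // BLOCK_SIZE, 1)                  # 2
# ROUND_DOWN_HIGHEST_POWER_OF_TWO(HEAD_SIZE // THREAD_GROUP_SIZE) = 64, already a power of two
VEC_SIZE = min(64, sizeof_float4 // sizeof_scalar)                   # 8 halves = one 16-byte load
NUM_VECS_PER_TOKEN = HEAD_SIZE // VEC_SIZE                           # 16
NUM_THREADS_PER_TOKEN = min(NUM_VECS_PER_TOKEN, WARP_SIZE)           # 16
NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN // NUM_THREADS_PER_TOKEN   # 1
WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN                       # 32
NUM_WARPS = NUM_THREADS // WARP_SIZE                                 # 4

# One warp covers BLOCK_SIZE * NUM_VECS_PER_TOKEN = 256 vector slots per KV block,
# WARP_STRIDE = 32 slots per iteration, i.e. 16 threads cooperate on each token.
assert (VEC_SIZE, NUM_VECS_PER_TOKEN, NUM_THREADS_PER_TOKEN, NUM_WARPS) == (8, 16, 16, 4)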

  using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
  using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
  using L_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
  using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;

  const int context_len = context_lens[seq_idx];
  const int thread_group_offset = thread_idx % NUM_THREADS_PER_TOKEN;
  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;

  __shared__ float4 q_shared[Q_SHARED_SIZE];
  __shared__ float red_shared_mem[2 * NUM_WARPS];
  extern __shared__ char shared_mem[];
  float* logits = reinterpret_cast<float*>(shared_mem);
  float* out_shared_mem = reinterpret_cast<float*>(shared_mem);
  float qk_max = -FLT_MAX;

  const float4* q_ptr = reinterpret_cast<const float4*>(q + seq_idx * q_stride + head_idx * HEAD_SIZE);
#pragma unroll
  for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) {
    q_shared[idx] = q_ptr[idx];
  }
  __syncthreads();

  scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
  // each warp accesses a whole block
  for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
#pragma unroll
    for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) {
      const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN;
      const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride
                                     + idx * VEC_SIZE;

      K_vec k_vecs[NUM_ROUNDS_PER_TOKEN];
      K_vec q_vecs[NUM_ROUNDS_PER_TOKEN];

      // we must calculate at least one row of hidden vectors
#pragma unroll
      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
        k_vecs[i] = (reinterpret_cast<const K_vec*>(k_ptr))[i * WARP_SIZE];
        q_vecs[i] = (reinterpret_cast<K_vec*>(q_shared_ptr))[(idx + i * WARP_SIZE) % NUM_VECS_PER_TOKEN];
      }

      float qk = scale * Qk_dot<scalar_t, NUM_THREADS_PER_TOKEN>::dot(q_vecs, k_vecs);

      if (thread_group_offset == 0) {
        const bool mask = token_idx >= context_len;
        logits[token_idx] = mask ? 0.f : qk;
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // there is a __syncthreads inside this function
  qk_max = block_max<NUM_WARPS, NUM_THREADS_PER_TOKEN>(red_shared_mem, qk_max);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }

  exp_sum = block_sum<NUM_WARPS>(&red_shared_mem[NUM_WARPS], exp_sum);
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  Float_vec accs[NUM_ROUNDS_PER_TOKEN];
#pragma unroll
  for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
    zero(accs[i]);
  }

  V_vec zero_value;
  zero(zero_value);
  for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    scalar_t logit;

#pragma unroll
    for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) {
      const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN;
      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride
                                     + idx * VEC_SIZE;

      V_vec v_vecs[NUM_ROUNDS_PER_TOKEN];

#pragma unroll
      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
        v_vecs[i] = (reinterpret_cast<const V_vec*>(v_ptr))[i * WARP_SIZE];
      }

      if (token_idx >= context_len) {
#pragma unroll
        for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
          v_vecs[i] = zero_value;
        }
      }

      logit = CastFunctor<float, scalar_t>()(logits[token_idx]);
#pragma unroll
      for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
        accs[i] = TernaryOpFunctor<scalar_t, V_vec, Float_vec, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
      }
    }
  }

  // A sync is required here since logits and out_shared_mem occupy the same buffer space.
  __syncthreads();

#pragma unroll
  for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
    block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
  }

  scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE;
  L_vec out_reg;
#pragma unroll
  for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
    if (thread_idx < NUM_THREADS_PER_TOKEN) {
      out_reg = CastFunctor<Float_vec, L_vec>()(accs[i]);
      (reinterpret_cast<L_vec*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
    }
  }
}

#define LAUNCH_FLASH_DECODING_ATTENTION_V1(HEAD_SIZE)                                            \
  cudaFuncSetAttribute(                                                                          \
      ((void*)flash_decoding_attention_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),  \
      cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                             \
  flash_decoding_attention_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                \
      <<<grid, block, shared_mem_size, stream>>>(                                                \
          reinterpret_cast<T*>(out.data_ptr()),                                                  \
          reinterpret_cast<T*>(query.data_ptr()),                                                \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                      \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                    \
          context_lens.data_ptr<int>(),                                                          \
          block_tables.data_ptr<int>(),                                                          \
          max_context_len,                                                                       \
          num_kv_heads,                                                                          \
          scale,                                                                                 \
          max_num_blocks_per_seq,                                                                \
          q_stride,                                                                              \
          kv_block_stride,                                                                       \
          kv_head_stride);

template<
    typename T,
    typename CACHE_T,
    int BLOCK_SIZE,
    int NUM_THREADS = 128>
void flash_decoding_attention_v1_launcher(
    torch::Tensor& out,           // [num_tokens, num_heads, head_size]
    torch::Tensor& query,         // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,     // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& value_cache,   // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& context_lens,  // [num_tokens]
    torch::Tensor& block_tables,  // [num_tokens, max_num_blocks_per_seq]
    int max_context_len,
    float scale) {
  int num_tokens = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int num_kv_heads = key_cache.size(1);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  const int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((head_size / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(T));
  const int NUM_VECS_PER_TOKEN = head_size / VEC_SIZE;
  const int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE);

  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float);
  // Keep this in sync with the kernel's shared-memory usage above.
  int shared_mem_size = std::max(logits_size, outputs_size);
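A worked example of this shared-memory budget (the numbers assume the same hypothetical fp16 configuration as above, with max_context_len = 1000 and BLOCK_SIZE = 16): the softmax logits buffer dominates for long contexts, while outputs_size is what the vectorized cross-warp block_sum needs, and the launcher reserves the larger of the two because the kernel reuses one dynamic buffer for both phases.

SIZEOF_FLOAT = 4
NUM_THREADS, WARP_SIZE = 128, 32
NUM_WARPS = NUM_THREADS // WARP_SIZE                     # 4
NUM_THREADS_PER_TOKEN, VEC_SIZE = 16, 8                  # from the fp16 / head_size=128 example
max_context_len, BLOCK_SIZE = 1000, 16

padded_max_context_len = -(-max_context_len // BLOCK_SIZE) * BLOCK_SIZE            # 1008
logits_size = padded_max_context_len * SIZEOF_FLOAT                                # 4032 bytes
outputs_size = (NUM_WARPS // 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * SIZEOF_FLOAT  # 1024 bytes
shared_mem_size = max(logits_size, outputs_size)                                   # 4032 bytes
assert (logits_size, outputs_size, shared_mem_size) == (4032, 1024, 4032)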

  dim3 grid(num_heads, num_tokens, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model.
    case 64:
      LAUNCH_FLASH_DECODING_ATTENTION_V1(64);
      break;
    case 128:
      LAUNCH_FLASH_DECODING_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_FLASH_DECODING_ATTENTION_V1(256);
      break;
    default:
      AT_ERROR("head size must be 64, 128, 256");
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE)                \
  flash_decoding_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE>( \
      out,                                                      \
      query,                                                    \
      key_cache,                                                \
      value_cache,                                              \
      context_lens,                                             \
      block_tables,                                             \
      max_context_len,                                          \
      scale);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T)   \
  switch (block_size) {                           \
    case 8:                                       \
      CALL_V1_LAUNCHER(T, CACHE_T, 8);            \
      break;                                      \
    case 16:                                      \
      CALL_V1_LAUNCHER(T, CACHE_T, 16);           \
      break;                                      \
    case 32:                                      \
      CALL_V1_LAUNCHER(T, CACHE_T, 32);           \
      break;                                      \
    default:                                      \
      AT_ERROR("block size must be 8, 16, 32");   \
      break;                                      \
  }

void flash_decoding_attention(
    torch::Tensor& out,           // [num_tokens, num_heads, head_size]
    torch::Tensor& query,         // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,     // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& value_cache,   // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& context_lens,  // [num_tokens]
    torch::Tensor& block_tables,  // [num_tokens, max_num_blocks_per_seq]
    int block_size,
    int max_context_len,
    torch::Tensor& tmp_out,       // [num_tokens, num_heads, max_num_partitions, head_size]
    torch::Tensor& tmp_out_lse,   // [num_tokens, num_heads, max_num_partitions]
    float scale) {
  switch (query.scalar_type()) {
    case at::ScalarType::Float:
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
      break;
    case at::ScalarType::Half:
      CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
      break;
    case at::ScalarType::BFloat16:
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
      break;
    default:
      AT_ERROR("Unsupported data type: ", toString(query.scalar_type()));
  }
}


#undef LAUNCH_FLASH_DECODING_ATTENTION_V1
#undef CALL_V1_LAUNCHER
#undef CALL_V1_LAUNCHER_BLOCK_SIZE
@@ -8,11 +8,20 @@
#include <functional>

#include "../utils/micros.h"
#include "../utils/vec_type_traits.h"
#include "cast_functor.h"

namespace colossalAI {
namespace cuda {
namespace funcs {

using utils::bfloat164;
using utils::bfloat168;
using utils::float4_;
using utils::float8_;
using utils::half4;
using utils::half8;

enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

// Note(LiuYang): This file provides base math operation for data type

@@ -22,73 +31,182 @@ enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
template <typename LT, typename RT, typename RET, BinaryOpType op_type>
struct BinaryOpFunctor;

#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
                                               FUNCTION_MODIFIER, ARGS...) \
  template <ARGS> \
  struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \
      : public std::binary_function<T, T, T> { \
    FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \
#define STMTS_WRAPPER(...) __VA_ARGS__

#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( \
    LT, RT, RET, BINARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
  template <ARGS> \
  struct BinaryOpFunctor<LT, RT, RET, BINARY_OP_TYPE> \
      : public std::binary_function<LT, RT, RET> { \
    FUNCTION_MODIFIER RET operator()(LT lhs, RT rhs) STMTS \
  };

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
                                       HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
                                       HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs,
                                       HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
                                       HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
                                       HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
                                       HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kAdd, HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs + rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMinus,
                                       HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs - rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMul, HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs * rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kDiv, HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs / rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMax, HOSTDEVICE,
                                       STMTS_WRAPPER({ return max(lhs, rhs); }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
                                       STMTS_WRAPPER({ return min(lhs, rhs); }),
                                       typename T)

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
                                       __hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
                                       __hadd2(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd2(lhs, rhs);
                                       }))

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
                                       __hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
                                       __hadd2(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                       __nv_bfloat16, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                       __nv_bfloat162, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd2(lhs, rhs);
                                       }))
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
                                       __float2bfloat16(__bfloat162float(lhs) +
                                                        __bfloat162float(rhs)),
                                       DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, BinaryOpType::kAdd,
    __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
                          __high2float(lhs) + __high2float(rhs)),
    DEVICE)
#endif
    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kAdd, DEVICE,
    STMTS_WRAPPER({
      return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
    }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
    STMTS_WRAPPER({
      return __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
                                   __high2float(lhs) + __high2float(rhs));
    }))
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
                                       __hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
                                       __hmul2(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul2(lhs, rhs);
                                       }))

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
                                       __hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
                                       __hmul2(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                       __nv_bfloat16, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                       __nv_bfloat162, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul2(lhs, rhs);
                                       }))
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
                                       __float2bfloat16(__bfloat162float(lhs) *
                                                        __bfloat162float(rhs)),
                                       DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, BinaryOpType::kMul,
    __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
                          __high2float(lhs) * __high2float(rhs)),
    DEVICE)
#endif
    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      return __float2bfloat16(__bfloat162float(lhs) * __bfloat162float(rhs));
    }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      return __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
                                   __high2float(lhs) * __high2float(rhs));
    }))
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    float2, float2, float2, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({ return make_float2(lhs.x * rhs.x, lhs.y * rhs.y); }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(float4, float4, float4,
                                       BinaryOpType::kMul, DEVICE,
                                       STMTS_WRAPPER({
                                         return make_float4(
                                             lhs.x * rhs.x, lhs.y * rhs.y,
                                             lhs.z * rhs.z, lhs.w * rhs.w);
                                       }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, float2, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      CastFunctor<__nv_bfloat162, float2> cast;
      BinaryOpFunctor<float2, float2, float2, BinaryOpType::kMul> mul;
      float2 fa = cast(lhs);
      float2 fb = cast(rhs);
      return mul(fa, fb);
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    bfloat164, bfloat164, float4_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
      float4_ fc;
      BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                      BinaryOpType::kMul>
          mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      return fc;
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    bfloat168, bfloat168, float8_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
      float8_ fc;
      BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                      BinaryOpType::kMul>
          mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      fc.z = mul(lhs.z, rhs.z);
      fc.w = mul(lhs.w, rhs.w);
      return fc;
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    half2, half2, float2, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
      CastFunctor<half2, float2> cast;
      BinaryOpFunctor<float2, float2, float2, BinaryOpType::kMul> mul;
      float2 fa = cast(lhs);
      float2 fb = cast(rhs);
      return mul(fa, fb);
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    half4, half4, float4_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
      float4_ fc;
      BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      return fc;
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    half8, half8, float8_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
      float8_ fc;
      BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      fc.z = mul(lhs.z, rhs.z);
      fc.w = mul(lhs.w, rhs.w);
      return fc;
    }))

#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION

#undef STMTS_WRAPPER

}  // namespace funcs
}  // namespace cuda
}  // namespace colossalAI
@@ -8,6 +8,7 @@
#include <functional>

#include "../utils/micros.h"
#include "../utils/vec_type_traits.h"

// Note(LiuYang): This file provides base math operation for data type
// include POD and cuda built-in type such as half and __nv_bfloat16

@@ -16,39 +17,150 @@ namespace colossalAI {
namespace cuda {
namespace funcs {

using utils::bfloat164;
using utils::bfloat168;
using utils::float4_;
using utils::float8_;
using utils::half4;
using utils::half8;

template <typename From, typename To>
struct CastFunctor : public std::unary_function<From, To> {
  HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
};

#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMTS, \
                                             FUNCTION_MODIFIER) \
  template <> \
  struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
    FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \
    FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
  };

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, float2, { return make_float2(val, val); }, DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, __float2half(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16,
                                     __float2bfloat16(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162,
                                     __float2bfloat162_rn(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    half2, float2, { return __half22float2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float2, half2, { return __float22half2_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, half, { return __float2half_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, half2, { return __float2half2_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    half, half2, { return __half2half2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    half, float, { return __half2float(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4, half4,
    {
      half4 dst;
      dst.x = __floats2half2_rn(val.x, val.y);
      dst.y = __floats2half2_rn(val.z, val.w);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4_, half4,
    {
      half4 dst;
      dst.x = __float22half2_rn(val.x);
      dst.y = __float22half2_rn(val.y);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float8_, half8,
    {
      half8 dst;
      dst.x = __float22half2_rn(val.x);
      dst.y = __float22half2_rn(val.y);
      dst.z = __float22half2_rn(val.z);
      dst.w = __float22half2_rn(val.w);
      return dst;
    },
    DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float,
                                     __bfloat162float(val), DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, __nv_bfloat162, { return __float2bfloat162_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4, bfloat164,
    {
      bfloat164 dst;
      dst.x = __floats2bfloat162_rn(val.x, val.y);
      dst.y = __floats2bfloat162_rn(val.z, val.w);
      return dst;
    },
    DEVICE)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat162, { return __bfloat162bfloat162(val); },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, float2, { return __bfloat1622float2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4_, bfloat164,
    {
      bfloat164 dst;
      dst.x = __float22bfloat162_rn(val.x);
      dst.y = __float22bfloat162_rn(val.y);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float8_, bfloat168,
    {
      bfloat168 dst;
      dst.x = __float22bfloat162_rn(val.x);
      dst.y = __float22bfloat162_rn(val.y);
      dst.z = __float22bfloat162_rn(val.z);
      dst.w = __float22bfloat162_rn(val.w);
      return dst;
    },
    DEVICE)
#else
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat162,
    {
      __nv_bfloat162 dst;
      dst.x = val;
      dst.y = val;
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, float2,
    { return make_float2(__low2float(val), __high2float(val)); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4_, bfloat164,
    {
      bfloat164 dst;
      dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
      dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float8_, bfloat168,
    {
      bfloat168 dst;
      dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
      dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
      dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
      dst.w = __floats2bfloat162_rn(val.w.x, val.w.y);
      return dst;
    },
    DEVICE)
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */

#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
}  // namespace funcs
@@ -0,0 +1,212 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>

#include <functional>

#include "../funcs/cast_functor.h"
#include "../utils/micros.h"

namespace colossalAI {
namespace cuda {
namespace funcs {

enum class TernaryOpType { kFma = 0 };

template <typename LT, typename RT, typename RET, TernaryOpType op_type>
struct TernaryOpFunctor;

#define STMTS_WRAPPER(...) __VA_ARGS__

#define COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( \
    LT, RT, RET, TERNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
  template <ARGS> \
  struct TernaryOpFunctor<LT, RT, RET, TERNARY_OP_TYPE> { \
    FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS \
  };
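The specializations that follow all implement one pattern: fused multiply-add where the two multiplicands may be half or bfloat16 vectors but the accumulator c (and the result) stays in float, which is what keeps the attention accumulation into accs[] numerically stable. The short PyTorch sketch below (illustrative only, not this header's API) shows why accumulating in the low-precision type would be a problem.

import torch

# Accumulate 4096 products of magnitude ~1e-2, roughly the kind of sum the
# logits * V loop performs per output element.
a = torch.full((4096,), 0.1, dtype=torch.float16)
b = torch.full((4096,), 0.1, dtype=torch.float16)

acc_fp16 = torch.zeros((), dtype=torch.float16)
acc_fp32 = torch.zeros((), dtype=torch.float32)
for x, y in zip(a, b):
    acc_fp16 = acc_fp16 + x * y                    # fp16 accumulator: visible rounding error
    acc_fp32 = acc_fp32 + (x.float() * y.float())  # fp32 accumulator, as the kFma functors do

# the exact value is 4096 * 0.01 = 40.96; the fp16 accumulator drifts noticeably
print(float(acc_fp16), float(acc_fp32))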

COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float d;
                                          d = fma(a, b, c);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float2, float2, float2,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float2 d;
                                          d.x = fma(a.x, b.x, c.x);
                                          d.y = fma(a.y, b.y, c.y);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float2, float2,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float2 d;
                                          d.x = fma(a, b.x, c.x);
                                          d.y = fma(a, b.y, c.y);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float4, float4, float4,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float4 d;
                                          d.x = fma(a.x, b.x, c.x);
                                          d.y = fma(a.y, b.y, c.y);
                                          d.z = fma(a.z, b.z, c.z);
                                          d.w = fma(a.w, b.w, c.w);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float4, float4,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float4 d;
                                          d.x = fma(a, b.x, c.x);
                                          d.y = fma(a, b.y, c.y);
                                          d.z = fma(a, b.z, c.z);
                                          d.w = fma(a, b.w, c.w);
                                          return d;
                                        }))

COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, half, float, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({ return __half2float(a) * __half2float(b) + c; }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half2, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      CastFunctor<half2, float2> cast;
      TernaryOpFunctor<float2, float2, float2, TernaryOpType::kFma> fma;
      float2 fa = cast(a);
      float2 fb = cast(b);
      return fma(fa, fb, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      CastFunctor<half, half2> cast;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      return fma(cast(a), b, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half4, half4, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      float4_ fd;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, half4, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      float4_ fd;
      CastFunctor<half, half2> cast;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      half2 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half8, half8, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      float8_ fd;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      fd.z = fma(a.z, b.z, c.z);
      fd.w = fma(a.w, b.w, c.w);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, half8, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      float8_ fd;
      CastFunctor<half, half2> cast;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      half2 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      fd.z = fma(s, b.z, c.z);
      fd.w = fma(s, b.w, c.w);
      return fd;
    }))

COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat16, float, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({ return __bfloat162float(a) * __bfloat162float(b) + c; }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      CastFunctor<__nv_bfloat162, float2> cast;
      TernaryOpFunctor<float2, float2, float2, TernaryOpType::kFma> fma;
      float2 fa = cast(a);
      float2 fb = cast(b);
      return fma(fa, fb, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      return fma(cast(a), b, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    bfloat164, bfloat164, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      float4_ fd;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, bfloat164, float4_, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      float4_ fd;
      CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      __nv_bfloat162 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    bfloat168, bfloat168, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      float8_ fd;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      fd.z = fma(a.z, b.z, c.z);
      fd.w = fma(a.w, b.w, c.w);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, bfloat168, float8_, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      float8_ fd;
      CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      __nv_bfloat162 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      fd.z = fma(s, b.z, c.z);
      fd.w = fma(s, b.w, c.w);
      return fd;
    }))

#undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION

#undef STMTS_WRAPPER

}  // namespace funcs
}  // namespace cuda
}  // namespace colossalAI
@ -13,9 +13,24 @@ namespace colossalAI {
namespace cuda {
namespace funcs {

template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ii++) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

// Note(LiuYang): As a retrieved table to check which operation is supported
// already
enum class UnaryOpType { kLog2Ceil = 0, kAbs };
enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum };

// Note(LiuYang): Implementation of common and simple unary operators should be
// placed here, otherwise, they should be placed in a new file under functors

@ -42,6 +57,25 @@ COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
                                        return log2_value;
                                      })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE,
                                      { return val.x + val.y; })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE,
                                      { return val.x + val.y + val.z + val.w; })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4_, float, UnaryOpType::kSum, DEVICE,
                                      {
                                        return val.x.x + val.x.y + val.y.x +
                                               val.y.y;
                                      })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float8_, float, UnaryOpType::kSum, DEVICE,
                                      {
                                        return val.x.x + val.x.y + val.y.x +
                                               val.y.y + val.z.x + val.z.y +
                                               val.w.x + val.w.y;
                                      })

#undef COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION

}  // namespace funcs

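A hedged sketch (not part of the commit) of how the zero<T>() helper and the UnaryOpType::kSum specializations above are typically combined: zero-initialize a wide fp32 accumulator, accumulate per-lane partial products, then collapse it with a horizontal sum. The f4_ struct stands in for the float4_ wrapper defined later in this diff, and the sum is written out exactly as the float4_ kSum body above.

struct f4_ { float2 x, y; };  // stand-in for float4_

template <typename T>
__device__ void zero_words(T& dst) {
  // Same idea as zero<T>() above: view T as raw 32-bit words and clear them,
  // so any vector wrapper can be initialized without a per-type constructor.
  constexpr int WORDS = sizeof(T) / 4;
  union { T raw; unsigned int words[WORDS]; } tmp;
#pragma unroll
  for (int ii = 0; ii < WORDS; ii++) tmp.words[ii] = 0u;
  dst = tmp.raw;
}

__device__ float horizontal_sum(const f4_& val) {
  // Mirrors the UnaryOpType::kSum specialization for float4_ above.
  return val.x.x + val.x.y + val.y.x + val.y.y;
}
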
@ -58,6 +58,21 @@ void get_cos_and_sin(at::Tensor& cos_cache,  // [max_rotary_position, head_dim]
                     at::Tensor& sequence_lengths,  // [batch_size]
                     int max_seq_len_in_batch, bool is_prompts);

void flash_decoding_attention(
    torch::Tensor& out,    // [num_tokens, num_heads, head_size]
    torch::Tensor& query,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& context_lens,  // [num_tokens]
    torch::Tensor& block_tables,  // [num_tokens, max_num_blocks_per_seq]
    int block_size, int max_context_len,
    torch::Tensor&
        tmp_out,  // [num_tokens, num_heads, max_num_partitions, head_size]
    torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
    float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
        "Copy the GPU memory of kvcache during the decode stage.");

@ -81,4 +96,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "In-place fused Add and RMS Normalization.");

  m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache.");

  m.def("flash_decoding_attention", &flash_decoding_attention,
        "Compute the attention between an input query and the cached "
        "keys/values using PagedAttention.");
}

@ -1,4 +1,4 @@
/*This code from VLLM:
/*This code from FasterTransformer:
 * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
 * with minor changes. */

@ -20,6 +20,32 @@ using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;
using colossalAI::cuda::utils::VecTypeTrait;

#define RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM)                                   \
  DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(                                            \
      input.element_size(),                                                          \
      input.scalar_type(),                                                           \
      "rms_layernorm_kernel",                                                        \
      rms_layernorm_kernel<scalar_t, UNROLL_FACTOR><<<grid, THREADDIM, 0, stream>>>( \
          out.data_ptr<scalar_t>(),                                                  \
          input.data_ptr<scalar_t>(),                                                \
          weight.data_ptr<scalar_t>(),                                               \
          epsilon,                                                                   \
          num_tokens,                                                                \
          hidden_size);)                                                             \

#define FUSED_ADD_RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM)                                   \
  DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(                                                      \
      input.element_size(),                                                                    \
      input.scalar_type(),                                                                     \
      "fused_add_rms_layernorm_kernel",                                                        \
      fused_add_rms_layernorm_kernel<scalar_t, UNROLL_FACTOR><<<grid, THREADDIM, 0, stream>>>( \
          input.data_ptr<scalar_t>(),                                                          \
          residual.data_ptr<scalar_t>(),                                                       \
          weight.data_ptr<scalar_t>(),                                                         \
          epsilon,                                                                             \
          num_tokens,                                                                          \
          hidden_size);)                                                                       \

// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void rms_layernorm_kernel(

@ -234,29 +260,9 @@ void rms_layernorm(

  if (num_tokens >= 512) {
    if (input.scalar_type() == at::ScalarType::Float) {
      DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
        input.element_size(),
        input.scalar_type(),
        "rms_layernorm_kernel",
        rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
          out.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(),
          epsilon,
          num_tokens,
          hidden_size);)
      RMSNORM_LAUNCHER(8, hidden_size / 8);
    } else {
      DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
        input.element_size(),
        input.scalar_type(),
        "rms_layernorm_kernel",
        rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
          out.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(),
          epsilon,
          num_tokens,
          hidden_size);)
      RMSNORM_LAUNCHER(4, hidden_size / 8);
    }
  } else {
    int unroll_factor = (hidden_size + block.x - 1) / block.x;

@ -266,56 +272,16 @@ void rms_layernorm(
    }
    switch (unroll_factor) {
      case 1:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "rms_layernorm_kernel",
          rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        RMSNORM_LAUNCHER(1, block);
        break;
      case 2:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "rms_layernorm_kernel",
          rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        RMSNORM_LAUNCHER(2, block);
        break;
      case 4:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "rms_layernorm_kernel",
          rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        RMSNORM_LAUNCHER(4, block);
        break;
      case 8:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "rms_layernorm_kernel",
          rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        RMSNORM_LAUNCHER(8, block);
        break;
      default:
        AT_ERROR("unroll_factor must be 1, 2, 4 or 8");

@ -338,29 +304,9 @@ void fused_add_rms_layernorm(

  if (num_tokens >= 512) {
    if (input.scalar_type() == at::ScalarType::Float) {
      DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
        input.element_size(),
        input.scalar_type(),
        "fused_add_rms_layernorm_kernel",
        fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          residual.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(),
          epsilon,
          num_tokens,
          hidden_size);)
      FUSED_ADD_RMSNORM_LAUNCHER(8, hidden_size / 8);
    } else {
      DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
        input.element_size(),
        input.scalar_type(),
        "fused_add_rms_layernorm_kernel",
        fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          residual.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(),
          epsilon,
          num_tokens,
          hidden_size);)
      FUSED_ADD_RMSNORM_LAUNCHER(4, hidden_size / 8);
    }
  } else {
    int unroll_factor = (hidden_size + block.x - 1) / block.x;

@ -370,56 +316,16 @@ void fused_add_rms_layernorm(
    }
    switch (unroll_factor) {
      case 1:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "fused_add_rms_layernorm_kernel",
          fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        FUSED_ADD_RMSNORM_LAUNCHER(1, block);
        break;
      case 2:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "fused_add_rms_layernorm_kernel",
          fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        FUSED_ADD_RMSNORM_LAUNCHER(2, block);
        break;
      case 4:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "fused_add_rms_layernorm_kernel",
          fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        FUSED_ADD_RMSNORM_LAUNCHER(4, block);
        break;
      case 8:
        DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
          input.element_size(),
          input.scalar_type(),
          "fused_add_rms_layernorm_kernel",
          fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            epsilon,
            num_tokens,
            hidden_size);)
        FUSED_ADD_RMSNORM_LAUNCHER(8, block);
        break;
      default:
        AT_ERROR("unroll_factor must be 1, 2, 4 or 8");

@ -11,9 +11,45 @@ namespace colossalAI {
namespace cuda {
namespace utils {

struct bfloat164 {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};
struct bfloat168 {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

struct half4 {
  half2 x;
  half2 y;
};
struct half8 {
  half2 x;
  half2 y;
  half2 z;
  half2 w;
};

struct float4_ {
  float2 x;
  float2 y;
};
struct float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

template <typename T, int VecSize>
struct VecTypeTrait {};

template <typename T>
struct FloatVecTypeTrait {};

#define VEC_TYPE_TRAITS_SPECIALIZATION(T, VEC_SIZE, VECT, ARGS...) \
  template <ARGS>                                                  \
  struct VecTypeTrait<T, VEC_SIZE> {                               \

@ -31,13 +67,36 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float8_)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, bfloat164);
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, bfloat168);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, half4);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, half8);

#undef VEC_TYPE_TRAITS_SPECIALIZATION

#define FLOATVEC_TYPE_TRAITS_SPECIALIZATION(T, FLOATT, ARGS...) \
  template <ARGS>                                               \
  struct FloatVecTypeTrait<T> {                                 \
    using Type = FLOATT;                                        \
  };

FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2)
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4)
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(bfloat164, float4_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(bfloat168, float8_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half4, float4_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half8, float8_);

#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION

}  // namespace utils
}  // namespace cuda
}  // namespace colossalAI

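A hedged sketch (not part of the commit) of how the two trait tables above are meant to be consumed: a kernel templated on the scalar type and a vectorization width picks its packed load type via VecTypeTrait and the matching fp32 accumulator type via FloatVecTypeTrait. The member name ::Type for VecTypeTrait is an assumption inferred from the FloatVecTypeTrait macro shown in this hunk.

// using colossalAI::cuda::utils::VecTypeTrait;
// using colossalAI::cuda::utils::FloatVecTypeTrait;

template <typename scalar_t, int VecSize>
__device__ void load_key_vector(const scalar_t* ptr) {
  // e.g. scalar_t = half, VecSize = 8  ->  VecT = half8, FloatVecT = float8_
  using VecT      = typename VecTypeTrait<scalar_t, VecSize>::Type;
  using FloatVecT = typename FloatVecTypeTrait<VecT>::Type;

  VecT k_vec = *reinterpret_cast<const VecT*>(ptr);  // one vectorized 16-byte load
  // FloatVecT acc; zero(acc);  // fp32 accumulator, cleared with the zero<T> helper added above
  (void)k_vec;
}
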
@ -17,6 +17,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
                "cuda/activation_kernel.cu",
                "cuda/rms_layernorm_kernel.cu",
                "cuda/get_cos_and_sin_kernel.cu",
                "cuda/flash_decoding_attention_kernel.cu",
            ]
        ]
        return ret

@ -0,0 +1,274 @@
from itertools import product

import numpy as np
import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

from tests.test_infer.test_ops.triton.kernel_utils import (
    convert_kv_unpad_to_padded,
    create_attention_mask,
    generate_caches_and_block_tables_v2,
    generate_caches_and_block_tables_vllm,
    torch_attn_ref,
)

q_len = 1


def prepare_data(
    BATCH_SIZE: int,
    HEAD_SIZE: int,
    NUM_ATTN_HEADS: int,
    NUM_KV_HEADS: int,
    MAX_SEQ_LEN: int,
    dtype=torch.float16,
    device="cuda",
):
    # Use the provided maximum sequence length for each sequence when testing with the same context length,
    # otherwise generate random context lengths.
    # returns
    #   q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
    #   k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
    kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(kv_lengths).item()

    q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE)
    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
    kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)

    return q, k_unpad, v_unpad, kv_lengths


def numpy_allclose(x, y, rtol, atol):
    x_numpy = x.detach().cpu().numpy()
    y_numpy = y.detach().cpu().numpy()

    np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)


@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_flash_decoding_attention(
    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
):
    torch.manual_seed(123)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
    MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
    device = get_current_device()

    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
        BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
    )

    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    block_tables = block_tables.to(device=device)
    max_seq_len_across_batch = kv_seq_lengths.max().item()
    kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
    output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
    sm_scale = 1.0 / (HEAD_SIZE**0.5)

    k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
    v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
    torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)

    mid_output = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
    )
    mid_output_lse = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
    )

    if dtype == torch.float16:
        rtol = 1e-3
        atol = 1e-3

        high_precision_q = q.to(torch.float32)
        high_precision_k_torch = k_torch.to(torch.float32)
        high_precision_v_torch = v_torch.to(torch.float32)
        out_ref = torch_attn_ref(
            high_precision_q,
            high_precision_k_torch,
            high_precision_v_torch,
            torch_padding_mask,
            BATCH_SIZE,
            q_len,
            max_seq_len_across_batch,
            NUM_ATTN_HEADS,
            NUM_KV_HEADS,
            HEAD_SIZE,
        ).to(torch.float16)

    else:
        rtol = 1e-5
        atol = 1e-7

        out_ref = torch_attn_ref(
            q,
            k_torch,
            v_torch,
            torch_padding_mask,
            BATCH_SIZE,
            q_len,
            max_seq_len_across_batch,
            NUM_ATTN_HEADS,
            NUM_KV_HEADS,
            HEAD_SIZE,
        )

    inference_ops.flash_decoding_attention(
        output,
        q.squeeze(2),
        k_cache,
        v_cache,
        kv_seq_lengths,
        block_tables,
        BLOCK_SIZE,
        max_seq_len_across_batch,
        mid_output,
        mid_output_lse,
        sm_scale,
    )
    numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_vllm_flash_decoding_attention(
    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
):
    torch.manual_seed(123)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    try:
        from vllm._C import ops as vllm_ops
    except ImportError:
        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
    MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
    device = get_current_device()

    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
        BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
    )

    k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    block_tables = block_tables.to(device=device)
    max_seq_len_across_batch = kv_seq_lengths.max().item()
    output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
    sm_scale = 1.0 / (HEAD_SIZE**0.5)

    k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
    v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
    torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)

    if dtype == torch.float16:
        rtol = 1e-3
        atol = 1e-3

        high_precision_q = q.to(torch.float32)
        high_precision_k_torch = k_torch.to(torch.float32)
        high_precision_v_torch = v_torch.to(torch.float32)
        out_ref = torch_attn_ref(
            high_precision_q,
            high_precision_k_torch,
            high_precision_v_torch,
            torch_padding_mask,
            BATCH_SIZE,
            q_len,
            max_seq_len_across_batch,
            NUM_ATTN_HEADS,
            NUM_KV_HEADS,
            HEAD_SIZE,
        ).to(torch.float16)

    else:
        rtol = 1e-5
        atol = 1e-7

        out_ref = torch_attn_ref(
            q,
            k_torch,
            v_torch,
            torch_padding_mask,
            BATCH_SIZE,
            q_len,
            max_seq_len_across_batch,
            NUM_ATTN_HEADS,
            NUM_KV_HEADS,
            HEAD_SIZE,
        )

    alibi_slopes = None

    vllm_ops.paged_attention_v1(
        output,
        q.squeeze(2),
        k_cache,
        v_cache,
        NUM_KV_HEADS,
        sm_scale,
        block_tables,
        kv_seq_lengths,
        BLOCK_SIZE,
        max_seq_len_across_batch,
        alibi_slopes,
        "auto",
    )
    numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


if __name__ == "__main__":
    BATCH_SIZE = [1, 4, 7, 32]
    BLOCK_SIZE = [8, 16, 32]
    MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32]
    HEAD_SIZE = [64, 128]
    NUM_ATTN_HEADS = [16]
    KV_GROUP_NUM = [1, 2, 16]
    DTYPE = [torch.float16, torch.float32]
    test_combinations = list(
        product(BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE)
    )
    for (
        batch_size,
        block_size,
        max_num_blocks_per_seq,
        head_size,
        num_attn_heads,
        kv_group_num,
        dtype,
    ) in test_combinations:
        test_flash_decoding_attention(
            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
        )

@ -150,6 +150,51 @@ def mock_alloc_block_table_and_kvcache_v2(
    return block_tables


def mock_alloc_block_table_and_kvcache_vllm(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0

    _, num_kv_heads, head_dim = k.shape

    x = 16 // torch.tensor([], dtype=k.dtype).element_size()

    for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # open bound
        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for i in range(right_bound):
            if i == right_bound - 1:
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            # [block_size, num_kv_heads, head_dim/x, x] -> [num_kv_heads, head_dim/x, block_size, x]
            k_block = (
                k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
                .reshape(allocated_locs, num_kv_heads, head_dim // x, x)
                .permute(1, 2, 0, 3)
            )
            # [block_size, num_kv_heads, head_dim] -> [num_kv_heads, head_dim, block_size]
            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            k_cache[block_id, :, :, :allocated_locs, :] = k_block
            v_cache[block_id, :, :, :allocated_locs] = v_block

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables


def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    # Allocate 1 token on the block table for each seq in block tables.
    # It won't change provided context_lengths.

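A hedged sketch (not part of the commit) of the addressing implied by the vLLM-style key-cache layout mocked above, [num_blocks, num_kv_heads, head_dim/x, block_size, x] with x = 16 / sizeof(dtype). Writing the flattened offset out in CUDA makes it clear why the Python mock permutes k to (num_kv_heads, head_dim/x, block_size, x) before copying each block.

__device__ __forceinline__ int key_cache_offset(int block_id, int kv_head, int dim_idx,
                                                int token_in_block, int num_kv_heads,
                                                int head_dim, int block_size, int x) {
  int dim_outer = dim_idx / x;  // which 16-byte chunk of the head dimension
  int dim_inner = dim_idx % x;  // position inside that chunk
  return (((block_id * num_kv_heads + kv_head) * (head_dim / x) + dim_outer) * block_size +
          token_in_block) * x + dim_inner;
}
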
@ -206,6 +251,26 @@ def generate_caches_and_block_tables_v2(
    return k_cache, v_cache, block_tables


def generate_caches_and_block_tables_vllm(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape

    x = 16 // torch.tensor([], dtype=dtype).element_size()

    k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
    v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
    k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache_vllm(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables


def convert_kv_unpad_to_padded(
    k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor: