[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)

* add rotary embedding kernel

* add rotary_embedding_kernel

* add fused rotary_emb and kvcache memcopy

* add fused_rotary_emb_and_cache_kernel.cu

* add fused_rotary_emb_and_memcopy

* fix bugs in fused_rotary_emb_and_cache_kernel.cu

* fix ci bugs

* use vec memcopy and opt the  gloabl memory access

* fix code style

* fix test_rotary_embdding_unpad.py

* codes revised based on the review comments

* fix bugs about include path

* rm inline
pull/5460/head
yuehuayingxueluo 2024-03-13 17:20:03 +08:00 committed by GitHub
parent ed431de4e4
commit f366a5ea1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 928 additions and 78 deletions

View File

@ -320,8 +320,12 @@ class NopadLlamaAttention(LlamaAttention):
)
block_size = k_cache.size(-2)
if is_prompts:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
if use_cuda_kernel:
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
else:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
@ -337,9 +341,16 @@ class NopadLlamaAttention(LlamaAttention):
)
else:
if use_cuda_kernel:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
)
else:
decoding_fused_rotary_embedding(

View File

@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()

View File

@ -1,8 +1,11 @@
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
inference_ops = InferenceOpsLoader().load()
try:
import triton # noqa
@ -16,9 +19,19 @@ configs = [
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 11)],
line_arg="provider",
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
line_vals=[
"no_fused_triton_rotary_emb_func",
"fused_triton_rotary_emb_func",
"no_fused_cuda_rotary_emb_func",
"fused_cuda_rotary_emb_func",
],
line_names=[
"no_fused_triton_rotary_emb_func",
"fused_triton_rotary_emb_func",
"no_fused_cuda_rotary_emb_func",
"fused_cuda_rotary_emb_func",
],
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
@ -32,7 +45,7 @@ def benchmark_rotary_emb(
num_tokens: int,
num_kv_heads: int,
):
BATCH_SIZE = 4
BATCH_SIZE = 16
SEQ_LEN = num_tokens // BATCH_SIZE
max_num_blocks_per_seq = 8
block_size = 64
@ -68,7 +81,7 @@ def benchmark_rotary_emb(
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
if provider == "no_fused_rotary_emb_func":
if provider == "no_fused_triton_rotary_emb_func":
fn = lambda: [
rotary_embedding(new_q, new_k, cos, sin),
copy_kv_to_blocked_cache(
@ -77,7 +90,16 @@ def benchmark_rotary_emb(
]
elif provider == "fused_triton_rotary_emb_func":
fn = lambda: decoding_fused_rotary_embedding(
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
)
elif provider == "no_fused_cuda_rotary_emb_func":
fn = lambda: [
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
]
elif provider == "fused_cuda_rotary_emb_func":
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
)
else:
raise ValueError("Undefined provider")

View File

@ -1,7 +1,11 @@
import torch
import triton
from vllm._C import ops
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rotary_embedding
inference_ops = InferenceOpsLoader().load()
BATCH = 16
configs = [
@ -9,9 +13,9 @@ configs = [
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 12)],
line_arg="provider",
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
@ -48,12 +52,19 @@ def benchmark_rotary_emb(
cos_shape = (4096, head_dim // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
if provider == "torch_rotary_emb_func":
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
elif provider == "triton_rotary_emb_func":
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
cos_sin = torch.stack((cos, sin), dim=1).contiguous()
positions = torch.arange(num_tokens).cuda()
if provider == "triton_func":
fn = lambda: rotary_embedding(q, k, cos, sin)
elif provider == "colossal_cuda_func":
fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin)
elif provider == "vllm_cuda_func":
q = q.view(num_tokens, -1)
k = k.view(num_tokens, -1)
fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True)
else:
raise ValueError("Undefined provider")

View File

@ -0,0 +1,54 @@
import torch
from colossalai.kernel.triton import get_xine_cache
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
try:
import triton # noqa
except ImportError:
print("please install triton from https://github.com/openai/triton")
configs = [
triton.testing.Benchmark(
x_names=["max_num_tokens"],
x_vals=[2**i for i in range(6, 12)],
line_arg="provider",
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name="Get_cos-sin_func",
args={"batch_size": 16, "head_dim": 256},
)
]
@triton.testing.perf_report(configs)
def benchmark_get_xine_cache(
provider: str,
max_num_tokens: int,
batch_size: int,
head_dim: int,
):
warmup = 10
rep = 1000
dtype = torch.float16
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
if provider == "torch_get_cos_sin":
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
elif provider == "triton_get_cos_sin":
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
else:
raise ValueError("Undefined provider")
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if __name__ == "__main__":
benchmark_get_xine_cache.run(save_path=".", print_data=True)

View File

@ -0,0 +1,98 @@
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <cfloat>
#include "string"
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float *)dst) = *((float *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float4 *)dst) = *((float4 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
const c10::Half *src) {
*((float *)dst) = *((float *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
const c10::Half *src) {
*((float4 *)dst) = *((float4 *)src);
}
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
*((float4 *)dst) = *((float4 *)src);
}
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*((float4 *)dst) = *((float4 *)src);
*((float4 *)(dst + 4)) = *((float4 *)(src + 4));
}
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
const int max_aligned_size = 128;
const int dtype_size = sizeof(T) * 8;
const int vec_size = max_aligned_size / sizeof(T) / 8;
if (address % (dtype_size * 4) == 0) {
return std::min(4, vec_size);
} else if (address % (dtype_size * 2) == 0) {
return std::min(2, vec_size);
} else {
return 1;
}
}

View File

@ -39,6 +39,9 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
auto ins_shape = ins.sizes().vec();
ins_shape[0] = ins_shape[0]/2;
if (ins_shape[0] == 1) {
ins_shape.erase(ins_shape.begin());
}
auto outs = torch::zeros(ins_shape,ins.options());
auto outs_shape = ins.sizes().vec();

View File

@ -1,10 +1,10 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>
#include "../common/vector_copy_utils.h"
#include "../common/micros.h"
template<typename scalar_t>
template<typename scalar_t, int VecSize>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
@ -12,79 +12,146 @@ __global__ void decode_kv_cache_memcpy_kernel(
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int num_heads,
const int head_size,
const int head_num,
const int head_dim,
const int block_size,
const int key_stride,
const int value_stride,
const int64_t key_stride,
const int64_t value_stride,
const int block_table_stride
)
{
const int seq_id = blockIdx.x;
const int seq_len = sequence_lengths[seq_id] - 1;
const int seq_id_in_block_table = seq_len / block_size;
const int block_offset = seq_len % block_size;
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
const int hidden_size = num_heads * head_size;
const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size];
const int hidden_size = head_num * head_dim;
if ( block_id < 0 ) {
return ;
}
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
const int head_id = i / head_size;
const int head_offset = i % head_size;
const int key_src_id = seq_id * key_stride + i;
const int value_src_id = seq_id * value_stride + i;
const int target_src_id = block_id * hidden_size * block_size
+ head_id * block_size * head_size
+ block_offset * head_size + head_offset;
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_src_id] = key[key_src_id];
value_cache[target_src_id] = value[value_src_id];
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
}
void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables) // [batch_size, max_seq_len]
template<typename scalar_t>
void apply_decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);
int vec_size = get_vec_size<scalar_t>(key);
if (head_dim % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
}
int thread_nums = head_num * head_dim / vec_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"decode_kv_cache_memcpy",
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
num_heads,
head_size,
block_size,
key_stride,
value_stride,
block_table_stride
);)
dim3 block(std::min(thread_nums, 512));
switch (vec_size) {
case 1:
decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
case 2:
decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
case 4:
decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
void decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"decode_kv_cache_memcpy",
apply_decode_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
block_tables
);)
}

View File

@ -0,0 +1,472 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "../common/vector_copy_utils.h"
#include "../common/micros.h"
template <typename scalar_t, int VecSize>
__device__ void apply_emb_rotary_compute(
scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, const int64_t stride,
const int token_id, const int shard_block_size, const int half_head_dim,
const int head_num, const int head_dim) {
scalar_t x[VecSize];
scalar_t y[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
i += blockDim.x * VecSize) {
const int head_offset = i % half_head_dim;
const int shard_offset =
(head_offset / shard_block_size) * shard_block_size +
(head_offset % shard_block_size) / VecSize;
const int64_t addr_offset =
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
y[j] * sin_ptr[j * 32 + shard_offset];
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
x[j] * sin_ptr[j * 32 + shard_offset];
}
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
}
}
template <typename scalar_t, int VecSize>
__device__ void apply_kv_memcopy(
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
const int64_t stride, const int token_id, const int block_id,
const int hidden_size, const int block_size, const int block_offset,
const int head_dim, const int half_head_dim) {
for (int i = threadIdx.x * VecSize; i < hidden_size / 2;
i += blockDim.x * VecSize) {
const int head_id = i / half_head_dim;
const int head_offset = i % half_head_dim;
const int64_t src_id = token_id * stride + head_id * head_dim + head_offset;
const int64_t target_id = block_id * hidden_size * block_size +
head_id * block_size * head_dim +
block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
src + src_id + half_head_dim);
}
}
template <typename scalar_t, int VecSize>
__device__ void cos_sin_memory_access(
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
const int shard_block_size, const int cos_stride, const int sin_stride,
const int half_head_dim) {
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
// We assume that the value of head_dim is less than 128*128.
const int shard_offset = (i % shard_block_size) / VecSize;
const int shard_head =
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
}
}
template <typename scalar_t, int VecSize>
__device__ void apply_k_rotary_emb_compute(
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id,
const int block_table_stride, const int head_num, const int head_dim,
const int kv_head_num, const int block_size, const int half_head_dim,
const int shard_block_size) {
const int seq_len = sequence_lengths[token_id] - 1;
const int block_offset = seq_len % block_size;
const int block_id =
block_tables[token_id * block_table_stride + seq_len / block_size];
if (block_id < 0) {
return;
}
scalar_t x[VecSize];
scalar_t y[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
i += blockDim.x * VecSize) {
const int head_offset = i % half_head_dim;
const int shard_offset =
(head_offset / shard_block_size) * shard_block_size +
(head_offset % shard_block_size) / VecSize;
const int64_t addr_offset =
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
const int64_t target_id = block_id * head_num * head_dim * block_size +
(i / half_head_dim) * block_size * head_dim +
block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
y[j] * sin_ptr[j * 32 + shard_offset];
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
x[j] * sin_ptr[j * 32 + shard_offset];
}
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
out_y);
}
// apply value memcopy
apply_kv_memcopy<scalar_t, VecSize>(
value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
block_size, block_offset, head_dim, half_head_dim);
}
template<typename scalar_t, int VecSize>
__global__ void rotary_embedding_and_cache_copy_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
scalar_t* __restrict__ value,
const scalar_t* __restrict__ cos,
const scalar_t* __restrict__ sin,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int64_t query_stride,
const int64_t key_stride,
const int64_t value_stride,
const int64_t half_shard_element_num,
const int cos_stride,
const int sin_stride,
const int block_table_stride,
const int head_num,
const int head_dim,
const int kv_head_num,
const int block_size
) {
const int token_id = blockIdx.x;
const int half_head_dim = head_dim / 2;
const int shard_block_size = VecSize * 32;
extern __shared__ char shard_ptr[];
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
}
template<typename scalar_t, int VecSize>
__global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
const scalar_t* __restrict__ cos,
const scalar_t* __restrict__ sin,
const int64_t query_stride,
const int64_t key_stride,
const int64_t half_shard_element_num,
const int cos_stride,
const int sin_stride,
const int head_num,
const int head_dim,
const int kv_head_num
) {
const int token_id = blockIdx.x;
const int half_head_dim = head_dim / 2;
const int shard_block_size = VecSize * 32;
extern __shared__ char shard_ptr[];
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key
apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
}
template<typename scalar_t>
void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = query.size(0);
int head_num = query.size(1);
int head_dim = query.size(2);
int kv_head_num = key.size(1);
int block_size = key_cache.size(2);
int64_t query_stride = query.stride(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
int block_table_stride = block_tables.stride(0);
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int thread_nums = head_num * head_dim / vec_size / 2;
const int shard_block_size = vec_size * 32 * 2;
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
switch (vec_size) {
case 1:
rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
break;
case 2:
rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
break;
case 4:
rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t>
void apply_rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin // [total_tokens, head_dim]
){
int num_tokens = query.size(0);
int head_num = query.size(1);
int head_dim = query.size(2);
int kv_head_num = key.size(1);
int query_stride = query.stride(0);
int key_stride = key.stride(0);
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int thread_nums = head_num * head_dim / vec_size / 2;
const int shard_block_size = vec_size * 32 * 2;
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
switch (vec_size) {
case 1:
rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
query_stride,
key_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
head_num,
head_dim,
kv_head_num
);
break;
case 2:
rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
query_stride,
key_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
head_num,
head_dim,
kv_head_num
);
break;
case 4:
rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
query_stride,
key_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
head_num,
head_dim,
kv_head_num
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
void rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
query.scalar_type(),
"rotary_embedding_and_cache_copy",
apply_rotary_embedding_and_cache_copy<scalar_t>(
query,
key,
value,
cos,
sin,
key_cache,
value_cache,
sequence_lengths,
block_tables
);)
}
void rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin // [total_tokens, head_dim]
){
DISPATCH_FLOAT_HALF_AND_BFLOAT(
query.scalar_type(),
"rotary_embedding",
apply_rotary_embedding<scalar_t>(
query,
key,
cos,
sin
);)
}

View File

@ -9,6 +9,23 @@ void decode_kv_cache_memcpy(
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
void rotary_embedding(
torch::Tensor& query, // [total_tokens, head_num, head_dim]
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
torch::Tensor& cos, // [total_tokens, head_dim]
torch::Tensor& sin); // [total_tokens, head_dim]
void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
torch::Tensor& cos, // [num_tokens, head_dim]
torch::Tensor& sin, // [num_tokens, head_dim]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
torch::Tensor silu_and_mul(const torch::Tensor& ins);
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
@ -25,6 +42,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def(
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
m.def("rotary_embedding", &rotary_embedding,
"performing Rotary Embedding-related calculations.");
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
m.def("rms_layernorm", &rms_layernorm,

View File

@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
for fname in [
"cuda/pybind/inference.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
"cuda/fused_rotary_emb_and_cache_kernel.cu",
"cuda/activation_kernel.cu",
"cuda/rms_layernorm_kernel.cu",
]

View File

@ -22,15 +22,11 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = (
LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
.cuda()
.half()
)
).cuda()
model = model.eval()
inputs = [
@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50
if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)

View File

@ -0,0 +1,91 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
v_cache = torch.zeros_like(k_cache)
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
)
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
new_v = torch.randn_like(new_k)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
new_q_copy = new_q.clone()
new_k_copy = new_k.clone()
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
k_source = new_k_copy.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
v_source = new_v.squeeze()
assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
assert k_target.shape == k_source.shape
assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if __name__ == "__main__":
test_rotary_emb(16, 512, 4, 128, torch.float16)