mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Kernel]Add get_cos_and_sin Kernel (#5528)
* Add get_cos_and_sin kernel * fix code comments * fix code typos * merge common codes of get_cos_and_sin kernel. * Fixed a typo * Changed 'asset allclose' to 'assert equal'.pull/5546/head
parent
934e31afb2
commit
04aca9e55b
|
@ -101,12 +101,22 @@ def llama_model_forward(
|
|||
use_cuda_kernel = False
|
||||
|
||||
hidden_states = self.embed_tokens(input_tokens_ids)
|
||||
if use_cuda_kernel and inputmetadata != torch.float32 and use_flash_attn2:
|
||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
if use_cuda_kernel:
|
||||
if inputmetadata != torch.float32 and use_flash_attn2:
|
||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
|
||||
hidden_dim = self._cos_cached.size(-1)
|
||||
total_length = hidden_states.size(0)
|
||||
cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device)
|
||||
sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device)
|
||||
inference_ops.get_cos_and_sin(
|
||||
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
|
||||
)
|
||||
cos_sin = (cos, sin)
|
||||
|
||||
else:
|
||||
cu_seqlens = None
|
||||
|
||||
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
|
||||
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
|
||||
|
||||
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
|
||||
|
||||
|
|
|
@ -0,0 +1,215 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "utils/vector_copy_utils.h"
|
||||
#include "../common/micros.h"
|
||||
#include "stdio.h"
|
||||
|
||||
template <typename scalar_t, bool Aligned, int VecSize>
|
||||
__device__ void apply_cos_and_sin_memcopy(
|
||||
scalar_t* __restrict__ cos,
|
||||
scalar_t* __restrict__ sin,
|
||||
const scalar_t* __restrict__ cos_cache_ptr,
|
||||
const scalar_t* __restrict__ sin_cache_ptr,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int head_dim,
|
||||
const int dest_offset_id,
|
||||
const int src_offset_id
|
||||
) {
|
||||
|
||||
int begin_id = threadIdx.x * VecSize;
|
||||
|
||||
for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){
|
||||
copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id);
|
||||
copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id);
|
||||
}
|
||||
|
||||
if (!Aligned) {
|
||||
for (; begin_id < head_dim; ++begin_id ) {
|
||||
cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id];
|
||||
sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool Aligned, int VecSize>
|
||||
__global__ void apply_get_context_cos_and_sin_kernel(
|
||||
scalar_t* __restrict__ cos,
|
||||
scalar_t* __restrict__ sin,
|
||||
const scalar_t* __restrict__ cos_cache_ptr,
|
||||
const scalar_t* __restrict__ sin_cache_ptr,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ cumsum_lengths,
|
||||
const int batch_size,
|
||||
const int head_dim
|
||||
) {
|
||||
int token_id = blockIdx.x;
|
||||
if ( token_id >= sequence_lengths[blockIdx.y] ) {
|
||||
return ;
|
||||
}
|
||||
|
||||
int src_offset_id = token_id * head_dim;
|
||||
int dest_offset_id = src_offset_id;
|
||||
|
||||
if (blockIdx.y > 0) {
|
||||
dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim;
|
||||
}
|
||||
|
||||
apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>(
|
||||
cos,
|
||||
sin,
|
||||
cos_cache_ptr,
|
||||
sin_cache_ptr,
|
||||
sequence_lengths,
|
||||
head_dim,
|
||||
dest_offset_id,
|
||||
src_offset_id
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool Aligned, int VecSize>
|
||||
__global__ void apply_get_decode_cos_and_sin_kernel(
|
||||
scalar_t* __restrict__ cos,
|
||||
scalar_t* __restrict__ sin,
|
||||
const scalar_t* __restrict__ cos_cache_ptr,
|
||||
const scalar_t* __restrict__ sin_cache_ptr,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int batch_size,
|
||||
const int head_dim
|
||||
) {
|
||||
int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim;
|
||||
int dest_offset_id = blockIdx.y * head_dim;
|
||||
|
||||
apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>(
|
||||
cos,
|
||||
sin,
|
||||
cos_cache_ptr,
|
||||
sin_cache_ptr,
|
||||
sequence_lengths,
|
||||
head_dim,
|
||||
dest_offset_id,
|
||||
src_offset_id
|
||||
);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void apply_get_cos_and_sin(
|
||||
at::Tensor& cos_cache, // [max_rotary_position, head_dim]
|
||||
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
int max_seq_len_in_batch,
|
||||
bool is_prompts
|
||||
) {
|
||||
int token_num = cos.size(0);
|
||||
int head_dim = cos.size(1);
|
||||
int batch_size = sequence_lengths.size(0);
|
||||
|
||||
at::Tensor cumsum_lengths;
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(cos);
|
||||
|
||||
bool aligned = true;
|
||||
if (head_dim % vec_size != 0) {
|
||||
aligned = false;
|
||||
}
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
int block_size_y;
|
||||
int block_size_x;
|
||||
|
||||
if (is_prompts) {
|
||||
block_size_y = batch_size;
|
||||
block_size_x = max_seq_len_in_batch;
|
||||
// TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on.
|
||||
cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32);
|
||||
}
|
||||
else{
|
||||
block_size_y = batch_size;
|
||||
block_size_x = 1;
|
||||
}
|
||||
|
||||
int thread_nums = (head_dim + vec_size - 1) / vec_size;
|
||||
|
||||
dim3 grid(block_size_x, block_size_y);
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
|
||||
#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \
|
||||
do { \
|
||||
if (is_prompts){ \
|
||||
apply_get_context_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
|
||||
cos.data_ptr<scalar_t>(), \
|
||||
sin.data_ptr<scalar_t>(), \
|
||||
cos_cache.data_ptr<scalar_t>(), \
|
||||
sin_cache.data_ptr<scalar_t>(), \
|
||||
sequence_lengths.data_ptr<int>(), \
|
||||
cumsum_lengths.data_ptr<int>(), \
|
||||
batch_size, \
|
||||
head_dim \
|
||||
); \
|
||||
} \
|
||||
else { \
|
||||
apply_get_decode_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
|
||||
cos.data_ptr<scalar_t>(), \
|
||||
sin.data_ptr<scalar_t>(), \
|
||||
cos_cache.data_ptr<scalar_t>(), \
|
||||
sin_cache.data_ptr<scalar_t>(), \
|
||||
sequence_lengths.data_ptr<int>(), \
|
||||
batch_size, \
|
||||
head_dim \
|
||||
); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \
|
||||
do { \
|
||||
switch (vec_size) { \
|
||||
case 1: \
|
||||
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \
|
||||
break; \
|
||||
case 2: \
|
||||
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \
|
||||
break; \
|
||||
case 4: \
|
||||
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \
|
||||
break; \
|
||||
default: \
|
||||
AT_ERROR("Unsupported vectorized size ", vec_size); \
|
||||
break; \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
if (aligned) {
|
||||
GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
|
||||
}
|
||||
else {
|
||||
GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void get_cos_and_sin(
|
||||
at::Tensor& cos_cache, // [max_rotary_position, head_dim]
|
||||
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
int max_seq_len_in_batch,
|
||||
bool is_prompts
|
||||
) {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
cos.scalar_type(),
|
||||
"get_cos_and_sin",
|
||||
apply_get_cos_and_sin<scalar_t>(
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
cos,
|
||||
sin,
|
||||
sequence_lengths,
|
||||
max_seq_len_in_batch,
|
||||
is_prompts
|
||||
);)
|
||||
}
|
|
@ -51,6 +51,13 @@ void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size]
|
|||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon);
|
||||
|
||||
void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim]
|
||||
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
int max_seq_len_in_batch, bool is_prompts);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||
"Copy the GPU memory of kvcache during the decode stage.");
|
||||
|
@ -60,10 +67,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
|
||||
m.def(
|
||||
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
|
||||
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
|
||||
"Performing Rotary Embedding-related calculations and KVCache Memcopy.");
|
||||
|
||||
m.def("rotary_embedding", &rotary_embedding,
|
||||
"performing Rotary Embedding-related calculations.");
|
||||
"Performing Rotary Embedding-related calculations.");
|
||||
|
||||
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
||||
|
||||
|
@ -72,4 +79,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
|
||||
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
|
||||
"In-place fused Add and RMS Normalization.");
|
||||
|
||||
m.def("get_cos_and_sin", &get_cos_and_sin,
|
||||
"Get cos and sin from the cache.");
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
|||
"cuda/fused_rotary_emb_and_cache_kernel.cu",
|
||||
"cuda/activation_kernel.cu",
|
||||
"cuda/rms_layernorm_kernel.cu",
|
||||
"cuda/get_cos_and_sin_kernel.cu",
|
||||
]
|
||||
]
|
||||
return ret
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
|
||||
def numpy_equal(x, y):
|
||||
x_numpy = x.detach().cpu().numpy()
|
||||
y_numpy = y.detach().cpu().numpy()
|
||||
|
||||
np.testing.assert_equal(x_numpy, y_numpy)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
|
||||
@pytest.mark.parametrize("HEAD_DIM", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
|
||||
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
|
||||
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
|
||||
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
|
||||
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32)
|
||||
|
||||
max_seq_len_in_batch = lengths.max()
|
||||
|
||||
# prefill
|
||||
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
|
||||
|
||||
cos = torch.zeros_like(cos_ref)
|
||||
sin = torch.zeros_like(sin_ref)
|
||||
|
||||
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True)
|
||||
|
||||
numpy_equal(cos, cos_ref)
|
||||
numpy_equal(sin, sin_ref)
|
||||
|
||||
# decoding
|
||||
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
|
||||
|
||||
cos = torch.zeros_like(ncos_ref)
|
||||
sin = torch.zeros_like(nsin_ref)
|
||||
|
||||
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False)
|
||||
numpy_equal(cos, ncos_ref)
|
||||
numpy_equal(sin, nsin_ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_get_cos_and_sin(16, 4096, 256, torch.float16)
|
Loading…
Reference in New Issue