[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)

pull/5681/head
傅剑寒 2024-04-30 18:33:53 +08:00 committed by GitHub
parent 5cd75ce4c7
commit ef8e4ffe31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 226 additions and 125 deletions

View File

@ -4,6 +4,11 @@
#include "micros.h" #include "micros.h"
#if defined(COLOSSAL_WITH_CUDA)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif
namespace colossalAI { namespace colossalAI {
namespace common { namespace common {
@ -27,6 +32,18 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float; using Type = float;
}; };
#if defined(COLOSSAL_WITH_CUDA)
template <>
struct MPTypeTrait<half> {
using Type = float;
};
template <>
struct MPTypeTrait<__nv_bfloat16> {
using Type = float;
};
#endif
template <bool high_precision, typename T> template <bool high_precision, typename T>
struct ScalarTypeTrait { struct ScalarTypeTrait {
using Type = using Type =

View File

@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
typename T) typename T)
#if defined(COLOSSAL_WITH_CUDA) #if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
DEVICE, STMTS_WRAPPER({
return __hsub(lhs, rhs);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd, COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
DEVICE, STMTS_WRAPPER({ DEVICE, STMTS_WRAPPER({
return __hadd(lhs, rhs); return __hadd(lhs, rhs);
@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
DEVICE, STMTS_WRAPPER({ DEVICE, STMTS_WRAPPER({
return __hadd(lhs, rhs); return __hadd(lhs, rhs);
})) }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
__nv_bfloat16, BinaryOpType::kMinus,
DEVICE, STMTS_WRAPPER({
return __hsub(lhs, rhs);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162, COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
__nv_bfloat162, BinaryOpType::kAdd, __nv_bfloat162, BinaryOpType::kAdd,
DEVICE, STMTS_WRAPPER({ DEVICE, STMTS_WRAPPER({
@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
STMTS_WRAPPER({ STMTS_WRAPPER({
return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs)); return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
})) }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
STMTS_WRAPPER({
return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE, __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
STMTS_WRAPPER({ STMTS_WRAPPER({

View File

@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
STMTS_WRAPPER({ STMTS_WRAPPER({
return __float2bfloat16_rn(val); return __float2bfloat16_rn(val);
})) }))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
STMTS_WRAPPER({
return __bfloat162float(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE, COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({ STMTS_WRAPPER({
dtype::bfloat164 dst; dtype::bfloat164 dst;

View File

@ -192,12 +192,6 @@ void context_kv_cache_memcpy(
int max_seq_len_in_batch) int max_seq_len_in_batch)
{ {
TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
"Dtype of key should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
#define _(T, CacheT) \ #define _(T, CacheT) \
apply_context_kv_cache_memcpy<T, CacheT>( \ apply_context_kv_cache_memcpy<T, CacheT>( \
key, \ key, \

View File

@ -380,12 +380,6 @@ void flash_decoding_attention(
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
float scale) { float scale) {
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
"Dtype of query should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
if(key_cache.scalar_type() == at::ScalarType::Byte) if(key_cache.scalar_type() == at::ScalarType::Byte)
{ {
switch (query.scalar_type()) { switch (query.scalar_type()) {

View File

@ -5,20 +5,30 @@
#include "utils/vec_copy.h" #include "utils/vec_copy.h"
#include "common/micros.h" #include "common/micros.h"
#include "common/mp_type_traits.h" #include "common/mp_type_traits.h"
#include "funcs/cast_functor.h"
#include "funcs/binary_functor.h"
using colossalAI::cuda::utils::copy_vector; using colossalAI::cuda::utils::copy_vector;
using colossalAI::cuda::utils::get_vec_size; using colossalAI::cuda::utils::get_vec_size;
using colossalAI::cuda::utils::copy;
using colossalAI::funcs::CastFunctor;
using colossalAI::funcs::BinaryOpFunctor;
using colossalAI::funcs::BinaryOpType;
template <typename scalar_t, typename m_scalar_t, int VecSize> template <typename T, typename MT, int VecSize>
__device__ void apply_emb_rotary_compute( __device__ void apply_emb_rotary_compute(
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr, T* __restrict__ src, const MT* __restrict__ cos_ptr,
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride, const MT* __restrict__ sin_ptr, const int64_t stride,
const int token_id, const int shard_block_size, const int half_head_dim, const int token_id, const int shard_block_size, const int half_head_dim,
const int head_num, const int head_dim) { const int head_num, const int head_dim) {
scalar_t x[VecSize]; BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
scalar_t y[VecSize]; BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
scalar_t out_x[VecSize]; BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
scalar_t out_y[VecSize];
T x[VecSize];
T y[VecSize];
T out_x[VecSize];
T out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim; for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
i += blockDim.x * VecSize) { i += blockDim.x * VecSize) {
@ -29,25 +39,25 @@ __device__ void apply_emb_rotary_compute(
const int64_t addr_offset = const int64_t addr_offset =
token_id * stride + (i / half_head_dim) * head_dim + head_offset; token_id * stride + (i / half_head_dim) * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(x, src + addr_offset); copy<T, VecSize>(src + addr_offset, x);
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim); copy<T, VecSize>(src + addr_offset + half_head_dim, y);
#pragma unroll #pragma unroll
for (int j = 0; j < VecSize; j++) { for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] - out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x[j]), cos_ptr[j * 32 + shard_offset]),
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]); mul(CastFunctor<T, MT>()(y[j]), sin_ptr[j * 32 + shard_offset])));
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] + out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(y[j]), cos_ptr[j * 32 + shard_offset]),
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]); mul(CastFunctor<T, MT>()(x[j]), sin_ptr[j * 32 + shard_offset])));
} }
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x); copy<T, VecSize>(out_x, src + addr_offset);
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y); copy<T, VecSize>(out_y, src + addr_offset + half_head_dim);
} }
} }
template <typename scalar_t, int VecSize> template <typename T, typename CacheT, int VecSize>
__device__ void apply_kv_memcopy( __device__ void apply_kv_memcopy(
scalar_t* __restrict__ src, scalar_t* __restrict__ cache, T* __restrict__ src, CacheT* __restrict__ cache,
const int64_t stride, const int token_id, const int block_id, const int64_t stride, const int token_id, const int block_id,
const int hidden_size, const int block_size, const int block_offset, const int hidden_size, const int block_size, const int block_offset,
const int head_dim, const int half_head_dim) { const int head_dim, const int half_head_dim) {
@ -60,16 +70,15 @@ __device__ void apply_kv_memcopy(
head_id * block_size * head_dim + head_id * block_size * head_dim +
block_offset * head_dim + head_offset; block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id); copy<T, CacheT, VecSize>(src + src_id, cache + target_id);
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim, copy<T, CacheT, VecSize>(src + src_id + half_head_dim, cache + target_id + half_head_dim);
src + src_id + half_head_dim);
} }
} }
template <typename scalar_t, typename m_scalar_t, int VecSize> template <typename T, typename MT, int VecSize>
__device__ void cos_sin_memory_access( __device__ void cos_sin_memory_access(
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin, const T* __restrict__ cos, const T* __restrict__ sin,
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id, MT* cos_ptr, MT* sin_ptr, const int token_id,
const int shard_block_size, const int cos_stride, const int sin_stride, const int shard_block_size, const int cos_stride, const int sin_stride,
const int half_head_dim) { const int half_head_dim) {
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
@ -77,22 +86,26 @@ __device__ void cos_sin_memory_access(
const int shard_offset = (i % shard_block_size) / VecSize; const int shard_offset = (i % shard_block_size) / VecSize;
const int shard_head = const int shard_head =
(i / shard_block_size) * shard_block_size + i % VecSize * 32; (i / shard_block_size) * shard_block_size + i % VecSize * 32;
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]); cos_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(cos[token_id * cos_stride + i]);
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]); sin_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(sin[token_id * sin_stride + i]);
} }
} }
template <typename scalar_t, typename m_scalar_t, int VecSize> template <typename T, typename MT, typename CacheT, int VecSize>
__device__ void apply_k_rotary_emb_compute( __device__ void apply_k_rotary_emb_compute(
scalar_t* __restrict__ key, scalar_t* __restrict__ value, T* __restrict__ key, T* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, CacheT* __restrict__ key_cache, CacheT* __restrict__ value_cache,
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr, const MT* __restrict__ cos_ptr, const MT* __restrict__ sin_ptr,
const int* __restrict__ sequence_lengths, const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int64_t key_stride, const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id, const int64_t value_stride, const int token_id,
const int block_table_stride, const int head_num, const int head_dim, const int block_table_stride, const int head_num, const int head_dim,
const int kv_head_num, const int block_size, const int x, const int half_head_dim, const int kv_head_num, const int block_size, const int x, const int half_head_dim,
const int shard_block_size) { const int shard_block_size) {
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
const int seq_len = sequence_lengths[token_id] - 1; const int seq_len = sequence_lengths[token_id] - 1;
const int block_offset = seq_len % block_size; const int block_offset = seq_len % block_size;
const int block_id = const int block_id =
@ -102,10 +115,10 @@ __device__ void apply_k_rotary_emb_compute(
return; return;
} }
scalar_t x0[VecSize]; T x0[VecSize];
scalar_t x1[VecSize]; T x1[VecSize];
scalar_t out_x[VecSize]; T out_x[VecSize];
scalar_t out_y[VecSize]; T out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim; for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
i += blockDim.x * VecSize) { i += blockDim.x * VecSize) {
@ -123,37 +136,36 @@ __device__ void apply_k_rotary_emb_compute(
+ block_offset * x + block_offset * x
+ x_offset; + x_offset;
copy_vector<scalar_t, VecSize>(x0, key + addr_offset); copy<T, VecSize>(key + addr_offset, x0);
copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim); copy<T, VecSize>(key + addr_offset + half_head_dim, x1);
#pragma unroll #pragma unroll
for (int j = 0; j < VecSize; j++) { for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] - out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x0[j]), cos_ptr[j * 32 + shard_offset]),
static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]); mul(CastFunctor<T, MT>()(x1[j]), sin_ptr[j * 32 + shard_offset])));
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] + out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(x1[j]), cos_ptr[j * 32 + shard_offset]),
static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]); mul(CastFunctor<T, MT>()(x0[j]), sin_ptr[j * 32 + shard_offset])));
} }
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x); copy<T, CacheT, VecSize>(out_x, key_cache + target_id);
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size, copy<T, CacheT, VecSize>(out_y, key_cache + target_id + half_head_dim * block_size);
out_y);
} }
// apply value memcopy // apply value memcopy
apply_kv_memcopy<scalar_t, VecSize>( apply_kv_memcopy<T, CacheT, VecSize>(
value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim, value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim,
block_size, block_offset, head_dim, half_head_dim); block_size, block_offset, head_dim, half_head_dim);
} }
template<typename scalar_t, typename m_scalar_t, int VecSize> template<typename T, typename MT, typename CacheT, int VecSize>
__global__ void rotary_embedding_and_cache_copy_kernel( __global__ void rotary_embedding_and_cache_copy_kernel(
scalar_t* __restrict__ query, T* __restrict__ query,
scalar_t* __restrict__ key, T* __restrict__ key,
scalar_t* __restrict__ value, T* __restrict__ value,
const scalar_t* __restrict__ cos, const T* __restrict__ cos,
const scalar_t* __restrict__ sin, const T* __restrict__ sin,
scalar_t* __restrict__ key_cache, CacheT* __restrict__ key_cache,
scalar_t* __restrict__ value_cache, CacheT* __restrict__ value_cache,
const int* __restrict__ sequence_lengths, const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int* __restrict__ block_tables,
const int64_t query_stride, const int64_t query_stride,
@ -176,26 +188,26 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
extern __shared__ char shard_ptr[]; extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; MT *cos_ptr = reinterpret_cast<MT*>(shard_ptr);
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; MT *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy // apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads(); __syncthreads();
//compute query //compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv //compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size); apply_k_rotary_emb_compute<T, MT, CacheT, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
} }
template<typename scalar_t, typename m_scalar_t, int VecSize> template<typename T, typename MT, int VecSize>
__global__ void rotary_embedding_kernel( __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query, T* __restrict__ query,
scalar_t* __restrict__ key, T* __restrict__ key,
const scalar_t* __restrict__ cos, const T* __restrict__ cos,
const scalar_t* __restrict__ sin, const T* __restrict__ sin,
const int64_t query_stride, const int64_t query_stride,
const int64_t key_stride, const int64_t key_stride,
const int64_t half_shard_element_num, const int64_t half_shard_element_num,
@ -211,29 +223,29 @@ __global__ void rotary_embedding_kernel(
extern __shared__ char shard_ptr[]; extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; MT *cos_ptr = (MT*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; MT *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy // apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads(); __syncthreads();
//compute query //compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key //compute key
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); apply_emb_rotary_compute<T, MT, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
} }
#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \ #define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \ rotary_embedding_and_cache_copy_kernel<T, MT, CacheT, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
query.data_ptr<scalar_t>(), \ reinterpret_cast<T*>(query.data_ptr()), \
key.data_ptr<scalar_t>(), \ reinterpret_cast<T*>(key.data_ptr()), \
value.data_ptr<scalar_t>(), \ reinterpret_cast<T*>(value.data_ptr()), \
cos.data_ptr<scalar_t>(), \ reinterpret_cast<T*>(cos.data_ptr()), \
sin.data_ptr<scalar_t>(), \ reinterpret_cast<T*>(sin.data_ptr()), \
key_cache.data_ptr<scalar_t>(), \ reinterpret_cast<CacheT*>(key_cache.data_ptr()), \
value_cache.data_ptr<scalar_t>(), \ reinterpret_cast<CacheT*>(value_cache.data_ptr()), \
sequence_lengths.data_ptr<int>(), \ sequence_lengths.data_ptr<int>(), \
block_tables.data_ptr<int>(), \ block_tables.data_ptr<int>(), \
query_stride, \ query_stride, \
@ -250,7 +262,7 @@ __global__ void rotary_embedding_kernel(
x); \ x); \
template<typename scalar_t, bool high_precision> template<typename T, typename CacheT, bool high_precision>
void apply_rotary_embedding_and_cache_copy( void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim] at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim] at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
@ -276,9 +288,9 @@ void apply_rotary_embedding_and_cache_copy(
int sin_stride = sin.stride(0); int sin_stride = sin.stride(0);
int block_table_stride = block_tables.stride(0); int block_table_stride = block_tables.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type; using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
int vec_size = get_vec_size<scalar_t>(query); int vec_size = get_vec_size<T>(query);
if ((head_dim / 2) % vec_size != 0) { if ((head_dim / 2) % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize. // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
@ -293,7 +305,7 @@ void apply_rotary_embedding_and_cache_copy(
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512)); dim3 block(std::min(thread_nums, 512));
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size; int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
const int shared_memory_size = shard_element_num * sizeof(m_scalar_t); const int shared_memory_size = shard_element_num * sizeof(MT);
switch (vec_size) { switch (vec_size) {
case 1: case 1:
@ -313,7 +325,7 @@ void apply_rotary_embedding_and_cache_copy(
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
template<typename scalar_t, bool high_precision> template<typename T, bool high_precision>
void apply_rotary_embedding( void apply_rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim] at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim] at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
@ -330,9 +342,9 @@ void apply_rotary_embedding(
int cos_stride = cos.stride(0); int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0); int sin_stride = sin.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type; using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
int vec_size = get_vec_size<scalar_t>(query); int vec_size = get_vec_size<T>(query);
if ((head_dim / 2) % vec_size != 0) { if ((head_dim / 2) % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize. // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
@ -350,11 +362,11 @@ void apply_rotary_embedding(
switch (vec_size) { switch (vec_size) {
case 1: case 1:
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>( rotary_embedding_kernel<T, MT, 1><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
query.data_ptr<scalar_t>(), query.data_ptr<T>(),
key.data_ptr<scalar_t>(), key.data_ptr<T>(),
cos.data_ptr<scalar_t>(), cos.data_ptr<T>(),
sin.data_ptr<scalar_t>(), sin.data_ptr<T>(),
query_stride, query_stride,
key_stride, key_stride,
shard_element_num / 2, shard_element_num / 2,
@ -366,11 +378,11 @@ void apply_rotary_embedding(
); );
break; break;
case 2: case 2:
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>( rotary_embedding_kernel<T, MT, 2><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
query.data_ptr<scalar_t>(), query.data_ptr<T>(),
key.data_ptr<scalar_t>(), key.data_ptr<T>(),
cos.data_ptr<scalar_t>(), cos.data_ptr<T>(),
sin.data_ptr<scalar_t>(), sin.data_ptr<T>(),
query_stride, query_stride,
key_stride, key_stride,
shard_element_num / 2, shard_element_num / 2,
@ -382,11 +394,11 @@ void apply_rotary_embedding(
); );
break; break;
case 4: case 4:
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>( rotary_embedding_kernel<T, MT, 4><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
query.data_ptr<scalar_t>(), query.data_ptr<T>(),
key.data_ptr<scalar_t>(), key.data_ptr<T>(),
cos.data_ptr<scalar_t>(), cos.data_ptr<T>(),
sin.data_ptr<scalar_t>(), sin.data_ptr<T>(),
query_stride, query_stride,
key_stride, key_stride,
shard_element_num / 2, shard_element_num / 2,
@ -416,21 +428,81 @@ void rotary_embedding_and_cache_copy(
at::Tensor& block_tables, // [batch_size, max_seq_len] at::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision) bool high_precision)
{ {
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( #define _(T, CacheT, HIGH_PRECISION) \
high_precision, apply_rotary_embedding_and_cache_copy<T, CacheT, HIGH_PRECISION>( \
query.scalar_type(), query, \
"rotary_embedding_and_cache_copy", key, \
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>( value, \
query, cos, \
key, sin, \
value, key_cache, \
cos, value_cache, \
sin, sequence_lengths, \
key_cache, block_tables);
value_cache,
sequence_lengths, if(key_cache.scalar_type() == at::ScalarType::Byte)
block_tables {
);) if(high_precision) {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, uint8_t, true)
break;
case at::ScalarType::Half:
_(half, uint8_t, true)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, uint8_t, true)
break;
}
}
else {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, uint8_t, false)
break;
case at::ScalarType::Half:
_(half, uint8_t, false)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, uint8_t, false)
break;
}
}
}
else
{
if(high_precision) {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, float, true)
break;
case at::ScalarType::Half:
_(half, half, true)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, __nv_bfloat16, true)
break;
}
}
else {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, float, false)
break;
case at::ScalarType::Half:
_(half, half, false)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, __nv_bfloat16, false)
break;
}
}
}
#undef _
} }
void rotary_embedding( void rotary_embedding(

View File

@ -11,6 +11,7 @@ namespace colossalAI {
namespace cuda { namespace cuda {
namespace utils { namespace utils {
// Note(LiuYang): Depreciated
template <typename T, int vec_size> template <typename T, int vec_size>
__device__ __inline__ void copy_vector(T *dst, const T *src) { __device__ __inline__ void copy_vector(T *dst, const T *src) {
using VT = typename common::VecTypeTrait<T, vec_size>::Type; using VT = typename common::VecTypeTrait<T, vec_size>::Type;
@ -26,6 +27,7 @@ __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
*(reinterpret_cast<const float4 *>(src + 4)); *(reinterpret_cast<const float4 *>(src + 4));
} }
// Note(LiuYang): Depreciated
template <typename T, int VecSize> template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) { __device__ __inline__ void copy_zero_vector(T *dst) {
using VT = typename common::VecTypeTrait<T, VecSize>::Type; using VT = typename common::VecTypeTrait<T, VecSize>::Type;
@ -36,13 +38,12 @@ template <typename SrcT, typename DstT, int vec_size>
__device__ __inline__ void copy(const SrcT *src, DstT *dst) { __device__ __inline__ void copy(const SrcT *src, DstT *dst) {
using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type; using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type; using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
*(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()( *(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
*(reinterpret_cast<const SrcVT *>(src))); *(reinterpret_cast<const SrcVT *>(src)));
} }
template <typename T, int vec_size> template <typename T, int vec_size>
__device__ __inline__ void copy<T, T, vec_size>(const T *src, T *dst) { __device__ __inline__ void copy(const T *src, T *dst) {
using VT = typename common::VecTypeTrait<T, vec_size>::Type; using VT = typename common::VecTypeTrait<T, vec_size>::Type;
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src)); *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
} }