mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)
parent
5cd75ce4c7
commit
ef8e4ffe31
|
@ -4,6 +4,11 @@
|
|||
|
||||
#include "micros.h"
|
||||
|
||||
#if defined(COLOSSAL_WITH_CUDA)
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
namespace colossalAI {
|
||||
namespace common {
|
||||
|
||||
|
@ -27,6 +32,18 @@ struct MPTypeTrait<at::BFloat16> {
|
|||
using Type = float;
|
||||
};
|
||||
|
||||
#if defined(COLOSSAL_WITH_CUDA)
|
||||
template <>
|
||||
struct MPTypeTrait<half> {
|
||||
using Type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MPTypeTrait<__nv_bfloat16> {
|
||||
using Type = float;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <bool high_precision, typename T>
|
||||
struct ScalarTypeTrait {
|
||||
using Type =
|
||||
|
|
|
@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
|
|||
typename T)
|
||||
|
||||
#if defined(COLOSSAL_WITH_CUDA)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hsub(lhs, rhs);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hadd(lhs, rhs);
|
||||
|
@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
|
|||
DEVICE, STMTS_WRAPPER({
|
||||
return __hadd(lhs, rhs);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
|
||||
__nv_bfloat16, BinaryOpType::kMinus,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
return __hsub(lhs, rhs);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
|
||||
__nv_bfloat162, BinaryOpType::kAdd,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
|
@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||
STMTS_WRAPPER({
|
||||
return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
|
|
|
@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
|
|||
STMTS_WRAPPER({
|
||||
return __float2bfloat16_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __bfloat162float(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
|
|
|
@ -192,12 +192,6 @@ void context_kv_cache_memcpy(
|
|||
int max_seq_len_in_batch)
|
||||
{
|
||||
|
||||
TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Dtype of key should be float, half or bfloat16!");
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
|
||||
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
|
||||
|
||||
|
||||
#define _(T, CacheT) \
|
||||
apply_context_kv_cache_memcpy<T, CacheT>( \
|
||||
key, \
|
||||
|
|
|
@ -380,12 +380,6 @@ void flash_decoding_attention(
|
|||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float scale) {
|
||||
|
||||
|
||||
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Dtype of query should be float, half or bfloat16!");
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
|
||||
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
|
||||
|
||||
if(key_cache.scalar_type() == at::ScalarType::Byte)
|
||||
{
|
||||
switch (query.scalar_type()) {
|
||||
|
|
|
@ -5,20 +5,30 @@
|
|||
#include "utils/vec_copy.h"
|
||||
#include "common/micros.h"
|
||||
#include "common/mp_type_traits.h"
|
||||
#include "funcs/cast_functor.h"
|
||||
#include "funcs/binary_functor.h"
|
||||
|
||||
using colossalAI::cuda::utils::copy_vector;
|
||||
using colossalAI::cuda::utils::get_vec_size;
|
||||
using colossalAI::cuda::utils::copy;
|
||||
using colossalAI::funcs::CastFunctor;
|
||||
using colossalAI::funcs::BinaryOpFunctor;
|
||||
using colossalAI::funcs::BinaryOpType;
|
||||
|
||||
template <typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
template <typename T, typename MT, int VecSize>
|
||||
__device__ void apply_emb_rotary_compute(
|
||||
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
|
||||
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
|
||||
T* __restrict__ src, const MT* __restrict__ cos_ptr,
|
||||
const MT* __restrict__ sin_ptr, const int64_t stride,
|
||||
const int token_id, const int shard_block_size, const int half_head_dim,
|
||||
const int head_num, const int head_dim) {
|
||||
scalar_t x[VecSize];
|
||||
scalar_t y[VecSize];
|
||||
scalar_t out_x[VecSize];
|
||||
scalar_t out_y[VecSize];
|
||||
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
|
||||
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
|
||||
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
|
||||
|
||||
T x[VecSize];
|
||||
T y[VecSize];
|
||||
T out_x[VecSize];
|
||||
T out_y[VecSize];
|
||||
|
||||
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
|
||||
i += blockDim.x * VecSize) {
|
||||
|
@ -29,25 +39,25 @@ __device__ void apply_emb_rotary_compute(
|
|||
const int64_t addr_offset =
|
||||
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
|
||||
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
|
||||
copy<T, VecSize>(src + addr_offset, x);
|
||||
copy<T, VecSize>(src + addr_offset + half_head_dim, y);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
|
||||
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
|
||||
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x[j]), cos_ptr[j * 32 + shard_offset]),
|
||||
mul(CastFunctor<T, MT>()(y[j]), sin_ptr[j * 32 + shard_offset])));
|
||||
out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(y[j]), cos_ptr[j * 32 + shard_offset]),
|
||||
mul(CastFunctor<T, MT>()(x[j]), sin_ptr[j * 32 + shard_offset])));
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
|
||||
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
|
||||
copy<T, VecSize>(out_x, src + addr_offset);
|
||||
copy<T, VecSize>(out_y, src + addr_offset + half_head_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
template <typename T, typename CacheT, int VecSize>
|
||||
__device__ void apply_kv_memcopy(
|
||||
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
|
||||
T* __restrict__ src, CacheT* __restrict__ cache,
|
||||
const int64_t stride, const int token_id, const int block_id,
|
||||
const int hidden_size, const int block_size, const int block_offset,
|
||||
const int head_dim, const int half_head_dim) {
|
||||
|
@ -60,16 +70,15 @@ __device__ void apply_kv_memcopy(
|
|||
head_id * block_size * head_dim +
|
||||
block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
|
||||
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
|
||||
src + src_id + half_head_dim);
|
||||
copy<T, CacheT, VecSize>(src + src_id, cache + target_id);
|
||||
copy<T, CacheT, VecSize>(src + src_id + half_head_dim, cache + target_id + half_head_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
template <typename T, typename MT, int VecSize>
|
||||
__device__ void cos_sin_memory_access(
|
||||
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
|
||||
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
|
||||
const T* __restrict__ cos, const T* __restrict__ sin,
|
||||
MT* cos_ptr, MT* sin_ptr, const int token_id,
|
||||
const int shard_block_size, const int cos_stride, const int sin_stride,
|
||||
const int half_head_dim) {
|
||||
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
|
||||
|
@ -77,22 +86,26 @@ __device__ void cos_sin_memory_access(
|
|||
const int shard_offset = (i % shard_block_size) / VecSize;
|
||||
const int shard_head =
|
||||
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
|
||||
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
|
||||
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
|
||||
cos_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(cos[token_id * cos_stride + i]);
|
||||
sin_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(sin[token_id * sin_stride + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
template <typename T, typename MT, typename CacheT, int VecSize>
|
||||
__device__ void apply_k_rotary_emb_compute(
|
||||
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
|
||||
T* __restrict__ key, T* __restrict__ value,
|
||||
CacheT* __restrict__ key_cache, CacheT* __restrict__ value_cache,
|
||||
const MT* __restrict__ cos_ptr, const MT* __restrict__ sin_ptr,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables, const int64_t key_stride,
|
||||
const int64_t value_stride, const int token_id,
|
||||
const int block_table_stride, const int head_num, const int head_dim,
|
||||
const int kv_head_num, const int block_size, const int x, const int half_head_dim,
|
||||
const int shard_block_size) {
|
||||
|
||||
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
|
||||
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
|
||||
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
|
||||
const int seq_len = sequence_lengths[token_id] - 1;
|
||||
const int block_offset = seq_len % block_size;
|
||||
const int block_id =
|
||||
|
@ -102,10 +115,10 @@ __device__ void apply_k_rotary_emb_compute(
|
|||
return;
|
||||
}
|
||||
|
||||
scalar_t x0[VecSize];
|
||||
scalar_t x1[VecSize];
|
||||
scalar_t out_x[VecSize];
|
||||
scalar_t out_y[VecSize];
|
||||
T x0[VecSize];
|
||||
T x1[VecSize];
|
||||
T out_x[VecSize];
|
||||
T out_y[VecSize];
|
||||
|
||||
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
|
||||
i += blockDim.x * VecSize) {
|
||||
|
@ -123,37 +136,36 @@ __device__ void apply_k_rotary_emb_compute(
|
|||
+ block_offset * x
|
||||
+ x_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(x0, key + addr_offset);
|
||||
copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim);
|
||||
copy<T, VecSize>(key + addr_offset, x0);
|
||||
copy<T, VecSize>(key + addr_offset + half_head_dim, x1);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] -
|
||||
static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] +
|
||||
static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x0[j]), cos_ptr[j * 32 + shard_offset]),
|
||||
mul(CastFunctor<T, MT>()(x1[j]), sin_ptr[j * 32 + shard_offset])));
|
||||
out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(x1[j]), cos_ptr[j * 32 + shard_offset]),
|
||||
mul(CastFunctor<T, MT>()(x0[j]), sin_ptr[j * 32 + shard_offset])));
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size,
|
||||
out_y);
|
||||
copy<T, CacheT, VecSize>(out_x, key_cache + target_id);
|
||||
copy<T, CacheT, VecSize>(out_y, key_cache + target_id + half_head_dim * block_size);
|
||||
}
|
||||
|
||||
// apply value memcopy
|
||||
apply_kv_memcopy<scalar_t, VecSize>(
|
||||
apply_kv_memcopy<T, CacheT, VecSize>(
|
||||
value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim,
|
||||
block_size, block_offset, head_dim, half_head_dim);
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
template<typename T, typename MT, typename CacheT, int VecSize>
|
||||
__global__ void rotary_embedding_and_cache_copy_kernel(
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ value,
|
||||
const scalar_t* __restrict__ cos,
|
||||
const scalar_t* __restrict__ sin,
|
||||
scalar_t* __restrict__ key_cache,
|
||||
scalar_t* __restrict__ value_cache,
|
||||
T* __restrict__ query,
|
||||
T* __restrict__ key,
|
||||
T* __restrict__ value,
|
||||
const T* __restrict__ cos,
|
||||
const T* __restrict__ sin,
|
||||
CacheT* __restrict__ key_cache,
|
||||
CacheT* __restrict__ value_cache,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables,
|
||||
const int64_t query_stride,
|
||||
|
@ -176,26 +188,26 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
|
|||
|
||||
extern __shared__ char shard_ptr[];
|
||||
|
||||
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
|
||||
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
MT *cos_ptr = reinterpret_cast<MT*>(shard_ptr);
|
||||
MT *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
|
||||
// apply cos_sin memcopy
|
||||
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
__syncthreads();
|
||||
|
||||
//compute query
|
||||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key and copy kv
|
||||
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
|
||||
apply_k_rotary_emb_compute<T, MT, CacheT, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
template<typename T, typename MT, int VecSize>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
const scalar_t* __restrict__ cos,
|
||||
const scalar_t* __restrict__ sin,
|
||||
T* __restrict__ query,
|
||||
T* __restrict__ key,
|
||||
const T* __restrict__ cos,
|
||||
const T* __restrict__ sin,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t half_shard_element_num,
|
||||
|
@ -211,29 +223,29 @@ __global__ void rotary_embedding_kernel(
|
|||
|
||||
extern __shared__ char shard_ptr[];
|
||||
|
||||
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
|
||||
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
MT *cos_ptr = (MT*)shard_ptr;
|
||||
MT *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
|
||||
// apply cos_sin memcopy
|
||||
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
__syncthreads();
|
||||
|
||||
//compute query
|
||||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key
|
||||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||
apply_emb_rotary_compute<T, MT, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||
}
|
||||
|
||||
#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
|
||||
query.data_ptr<scalar_t>(), \
|
||||
key.data_ptr<scalar_t>(), \
|
||||
value.data_ptr<scalar_t>(), \
|
||||
cos.data_ptr<scalar_t>(), \
|
||||
sin.data_ptr<scalar_t>(), \
|
||||
key_cache.data_ptr<scalar_t>(), \
|
||||
value_cache.data_ptr<scalar_t>(), \
|
||||
rotary_embedding_and_cache_copy_kernel<T, MT, CacheT, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
|
||||
reinterpret_cast<T*>(query.data_ptr()), \
|
||||
reinterpret_cast<T*>(key.data_ptr()), \
|
||||
reinterpret_cast<T*>(value.data_ptr()), \
|
||||
reinterpret_cast<T*>(cos.data_ptr()), \
|
||||
reinterpret_cast<T*>(sin.data_ptr()), \
|
||||
reinterpret_cast<CacheT*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CacheT*>(value_cache.data_ptr()), \
|
||||
sequence_lengths.data_ptr<int>(), \
|
||||
block_tables.data_ptr<int>(), \
|
||||
query_stride, \
|
||||
|
@ -250,7 +262,7 @@ __global__ void rotary_embedding_kernel(
|
|||
x); \
|
||||
|
||||
|
||||
template<typename scalar_t, bool high_precision>
|
||||
template<typename T, typename CacheT, bool high_precision>
|
||||
void apply_rotary_embedding_and_cache_copy(
|
||||
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
|
@ -276,9 +288,9 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
int sin_stride = sin.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
|
||||
using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(query);
|
||||
int vec_size = get_vec_size<T>(query);
|
||||
|
||||
if ((head_dim / 2) % vec_size != 0) {
|
||||
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||
|
@ -293,7 +305,7 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
|
||||
const int shared_memory_size = shard_element_num * sizeof(m_scalar_t);
|
||||
const int shared_memory_size = shard_element_num * sizeof(MT);
|
||||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
|
@ -313,7 +325,7 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool high_precision>
|
||||
template<typename T, bool high_precision>
|
||||
void apply_rotary_embedding(
|
||||
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
|
@ -330,9 +342,9 @@ void apply_rotary_embedding(
|
|||
int cos_stride = cos.stride(0);
|
||||
int sin_stride = sin.stride(0);
|
||||
|
||||
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
|
||||
using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(query);
|
||||
int vec_size = get_vec_size<T>(query);
|
||||
|
||||
if ((head_dim / 2) % vec_size != 0) {
|
||||
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||
|
@ -350,11 +362,11 @@ void apply_rotary_embedding(
|
|||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
rotary_embedding_kernel<T, MT, 1><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
|
||||
query.data_ptr<T>(),
|
||||
key.data_ptr<T>(),
|
||||
cos.data_ptr<T>(),
|
||||
sin.data_ptr<T>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
shard_element_num / 2,
|
||||
|
@ -366,11 +378,11 @@ void apply_rotary_embedding(
|
|||
);
|
||||
break;
|
||||
case 2:
|
||||
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
rotary_embedding_kernel<T, MT, 2><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
|
||||
query.data_ptr<T>(),
|
||||
key.data_ptr<T>(),
|
||||
cos.data_ptr<T>(),
|
||||
sin.data_ptr<T>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
shard_element_num / 2,
|
||||
|
@ -382,11 +394,11 @@ void apply_rotary_embedding(
|
|||
);
|
||||
break;
|
||||
case 4:
|
||||
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
rotary_embedding_kernel<T, MT, 4><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
|
||||
query.data_ptr<T>(),
|
||||
key.data_ptr<T>(),
|
||||
cos.data_ptr<T>(),
|
||||
sin.data_ptr<T>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
shard_element_num / 2,
|
||||
|
@ -416,21 +428,81 @@ void rotary_embedding_and_cache_copy(
|
|||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
bool high_precision)
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
|
||||
high_precision,
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_and_cache_copy",
|
||||
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cos,
|
||||
sin,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
block_tables
|
||||
);)
|
||||
#define _(T, CacheT, HIGH_PRECISION) \
|
||||
apply_rotary_embedding_and_cache_copy<T, CacheT, HIGH_PRECISION>( \
|
||||
query, \
|
||||
key, \
|
||||
value, \
|
||||
cos, \
|
||||
sin, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
sequence_lengths, \
|
||||
block_tables);
|
||||
|
||||
if(key_cache.scalar_type() == at::ScalarType::Byte)
|
||||
{
|
||||
if(high_precision) {
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, uint8_t, true)
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, uint8_t, true)
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, uint8_t, true)
|
||||
break;
|
||||
}
|
||||
}
|
||||
else {
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, uint8_t, false)
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, uint8_t, false)
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, uint8_t, false)
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(high_precision) {
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, float, true)
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, half, true)
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, __nv_bfloat16, true)
|
||||
break;
|
||||
}
|
||||
}
|
||||
else {
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, float, false)
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, half, false)
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, __nv_bfloat16, false)
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
#undef _
|
||||
}
|
||||
|
||||
void rotary_embedding(
|
||||
|
|
|
@ -11,6 +11,7 @@ namespace colossalAI {
|
|||
namespace cuda {
|
||||
namespace utils {
|
||||
|
||||
// Note(LiuYang): Depreciated
|
||||
template <typename T, int vec_size>
|
||||
__device__ __inline__ void copy_vector(T *dst, const T *src) {
|
||||
using VT = typename common::VecTypeTrait<T, vec_size>::Type;
|
||||
|
@ -26,6 +27,7 @@ __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
|
|||
*(reinterpret_cast<const float4 *>(src + 4));
|
||||
}
|
||||
|
||||
// Note(LiuYang): Depreciated
|
||||
template <typename T, int VecSize>
|
||||
__device__ __inline__ void copy_zero_vector(T *dst) {
|
||||
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
|
||||
|
@ -36,13 +38,12 @@ template <typename SrcT, typename DstT, int vec_size>
|
|||
__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
|
||||
using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
|
||||
using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
|
||||
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
|
||||
*(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
|
||||
*(reinterpret_cast<const SrcVT *>(src)));
|
||||
}
|
||||
|
||||
template <typename T, int vec_size>
|
||||
__device__ __inline__ void copy<T, T, vec_size>(const T *src, T *dst) {
|
||||
__device__ __inline__ void copy(const T *src, T *dst) {
|
||||
using VT = typename common::VecTypeTrait<T, vec_size>::Type;
|
||||
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue