[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)

pull/5681/head
傅剑寒 2024-04-30 18:33:53 +08:00 committed by GitHub
parent 5cd75ce4c7
commit ef8e4ffe31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 226 additions and 125 deletions

View File

@ -4,6 +4,11 @@
#include "micros.h"
#if defined(COLOSSAL_WITH_CUDA)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif
namespace colossalAI {
namespace common {
@ -27,6 +32,18 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};
#if defined(COLOSSAL_WITH_CUDA)
template <>
struct MPTypeTrait<half> {
using Type = float;
};
template <>
struct MPTypeTrait<__nv_bfloat16> {
using Type = float;
};
#endif
template <bool high_precision, typename T>
struct ScalarTypeTrait {
using Type =

View File

@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
typename T)
#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
DEVICE, STMTS_WRAPPER({
return __hsub(lhs, rhs);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
DEVICE, STMTS_WRAPPER({
return __hadd(lhs, rhs);
@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
DEVICE, STMTS_WRAPPER({
return __hadd(lhs, rhs);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
__nv_bfloat16, BinaryOpType::kMinus,
DEVICE, STMTS_WRAPPER({
return __hsub(lhs, rhs);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
__nv_bfloat162, BinaryOpType::kAdd,
DEVICE, STMTS_WRAPPER({
@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
STMTS_WRAPPER({
return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
STMTS_WRAPPER({
return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
STMTS_WRAPPER({

View File

@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
STMTS_WRAPPER({
return __float2bfloat16_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
STMTS_WRAPPER({
return __bfloat162float(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({
dtype::bfloat164 dst;

View File

@ -192,12 +192,6 @@ void context_kv_cache_memcpy(
int max_seq_len_in_batch)
{
TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
"Dtype of key should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
#define _(T, CacheT) \
apply_context_kv_cache_memcpy<T, CacheT>( \
key, \

View File

@ -380,12 +380,6 @@ void flash_decoding_attention(
const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
"Dtype of query should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
if(key_cache.scalar_type() == at::ScalarType::Byte)
{
switch (query.scalar_type()) {

View File

@ -5,20 +5,30 @@
#include "utils/vec_copy.h"
#include "common/micros.h"
#include "common/mp_type_traits.h"
#include "funcs/cast_functor.h"
#include "funcs/binary_functor.h"
using colossalAI::cuda::utils::copy_vector;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::cuda::utils::copy;
using colossalAI::funcs::CastFunctor;
using colossalAI::funcs::BinaryOpFunctor;
using colossalAI::funcs::BinaryOpType;
template <typename scalar_t, typename m_scalar_t, int VecSize>
template <typename T, typename MT, int VecSize>
__device__ void apply_emb_rotary_compute(
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
T* __restrict__ src, const MT* __restrict__ cos_ptr,
const MT* __restrict__ sin_ptr, const int64_t stride,
const int token_id, const int shard_block_size, const int half_head_dim,
const int head_num, const int head_dim) {
scalar_t x[VecSize];
scalar_t y[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
T x[VecSize];
T y[VecSize];
T out_x[VecSize];
T out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
i += blockDim.x * VecSize) {
@ -29,25 +39,25 @@ __device__ void apply_emb_rotary_compute(
const int64_t addr_offset =
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
copy<T, VecSize>(src + addr_offset, x);
copy<T, VecSize>(src + addr_offset + half_head_dim, y);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x[j]), cos_ptr[j * 32 + shard_offset]),
mul(CastFunctor<T, MT>()(y[j]), sin_ptr[j * 32 + shard_offset])));
out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(y[j]), cos_ptr[j * 32 + shard_offset]),
mul(CastFunctor<T, MT>()(x[j]), sin_ptr[j * 32 + shard_offset])));
}
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
copy<T, VecSize>(out_x, src + addr_offset);
copy<T, VecSize>(out_y, src + addr_offset + half_head_dim);
}
}
template <typename scalar_t, int VecSize>
template <typename T, typename CacheT, int VecSize>
__device__ void apply_kv_memcopy(
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
T* __restrict__ src, CacheT* __restrict__ cache,
const int64_t stride, const int token_id, const int block_id,
const int hidden_size, const int block_size, const int block_offset,
const int head_dim, const int half_head_dim) {
@ -60,16 +70,15 @@ __device__ void apply_kv_memcopy(
head_id * block_size * head_dim +
block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
src + src_id + half_head_dim);
copy<T, CacheT, VecSize>(src + src_id, cache + target_id);
copy<T, CacheT, VecSize>(src + src_id + half_head_dim, cache + target_id + half_head_dim);
}
}
template <typename scalar_t, typename m_scalar_t, int VecSize>
template <typename T, typename MT, int VecSize>
__device__ void cos_sin_memory_access(
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
const T* __restrict__ cos, const T* __restrict__ sin,
MT* cos_ptr, MT* sin_ptr, const int token_id,
const int shard_block_size, const int cos_stride, const int sin_stride,
const int half_head_dim) {
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
@ -77,22 +86,26 @@ __device__ void cos_sin_memory_access(
const int shard_offset = (i % shard_block_size) / VecSize;
const int shard_head =
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
cos_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(cos[token_id * cos_stride + i]);
sin_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(sin[token_id * sin_stride + i]);
}
}
template <typename scalar_t, typename m_scalar_t, int VecSize>
template <typename T, typename MT, typename CacheT, int VecSize>
__device__ void apply_k_rotary_emb_compute(
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
T* __restrict__ key, T* __restrict__ value,
CacheT* __restrict__ key_cache, CacheT* __restrict__ value_cache,
const MT* __restrict__ cos_ptr, const MT* __restrict__ sin_ptr,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id,
const int block_table_stride, const int head_num, const int head_dim,
const int kv_head_num, const int block_size, const int x, const int half_head_dim,
const int shard_block_size) {
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
const int seq_len = sequence_lengths[token_id] - 1;
const int block_offset = seq_len % block_size;
const int block_id =
@ -102,10 +115,10 @@ __device__ void apply_k_rotary_emb_compute(
return;
}
scalar_t x0[VecSize];
scalar_t x1[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
T x0[VecSize];
T x1[VecSize];
T out_x[VecSize];
T out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
i += blockDim.x * VecSize) {
@ -123,37 +136,36 @@ __device__ void apply_k_rotary_emb_compute(
+ block_offset * x
+ x_offset;
copy_vector<scalar_t, VecSize>(x0, key + addr_offset);
copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim);
copy<T, VecSize>(key + addr_offset, x0);
copy<T, VecSize>(key + addr_offset + half_head_dim, x1);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]);
out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x0[j]), cos_ptr[j * 32 + shard_offset]),
mul(CastFunctor<T, MT>()(x1[j]), sin_ptr[j * 32 + shard_offset])));
out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(x1[j]), cos_ptr[j * 32 + shard_offset]),
mul(CastFunctor<T, MT>()(x0[j]), sin_ptr[j * 32 + shard_offset])));
}
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size,
out_y);
copy<T, CacheT, VecSize>(out_x, key_cache + target_id);
copy<T, CacheT, VecSize>(out_y, key_cache + target_id + half_head_dim * block_size);
}
// apply value memcopy
apply_kv_memcopy<scalar_t, VecSize>(
apply_kv_memcopy<T, CacheT, VecSize>(
value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim,
block_size, block_offset, head_dim, half_head_dim);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
template<typename T, typename MT, typename CacheT, int VecSize>
__global__ void rotary_embedding_and_cache_copy_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
scalar_t* __restrict__ value,
const scalar_t* __restrict__ cos,
const scalar_t* __restrict__ sin,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
T* __restrict__ query,
T* __restrict__ key,
T* __restrict__ value,
const T* __restrict__ cos,
const T* __restrict__ sin,
CacheT* __restrict__ key_cache,
CacheT* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int64_t query_stride,
@ -176,26 +188,26 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
MT *cos_ptr = reinterpret_cast<MT*>(shard_ptr);
MT *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
apply_k_rotary_emb_compute<T, MT, CacheT, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
template<typename T, typename MT, int VecSize>
__global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
const scalar_t* __restrict__ cos,
const scalar_t* __restrict__ sin,
T* __restrict__ query,
T* __restrict__ key,
const T* __restrict__ cos,
const T* __restrict__ sin,
const int64_t query_stride,
const int64_t key_stride,
const int64_t half_shard_element_num,
@ -211,29 +223,29 @@ __global__ void rotary_embedding_kernel(
extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
MT *cos_ptr = (MT*)shard_ptr;
MT *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
apply_emb_rotary_compute<T, MT, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
}
#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
query.data_ptr<scalar_t>(), \
key.data_ptr<scalar_t>(), \
value.data_ptr<scalar_t>(), \
cos.data_ptr<scalar_t>(), \
sin.data_ptr<scalar_t>(), \
key_cache.data_ptr<scalar_t>(), \
value_cache.data_ptr<scalar_t>(), \
rotary_embedding_and_cache_copy_kernel<T, MT, CacheT, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
reinterpret_cast<T*>(query.data_ptr()), \
reinterpret_cast<T*>(key.data_ptr()), \
reinterpret_cast<T*>(value.data_ptr()), \
reinterpret_cast<T*>(cos.data_ptr()), \
reinterpret_cast<T*>(sin.data_ptr()), \
reinterpret_cast<CacheT*>(key_cache.data_ptr()), \
reinterpret_cast<CacheT*>(value_cache.data_ptr()), \
sequence_lengths.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
query_stride, \
@ -250,7 +262,7 @@ __global__ void rotary_embedding_kernel(
x); \
template<typename scalar_t, bool high_precision>
template<typename T, typename CacheT, bool high_precision>
void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
@ -276,9 +288,9 @@ void apply_rotary_embedding_and_cache_copy(
int sin_stride = sin.stride(0);
int block_table_stride = block_tables.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
int vec_size = get_vec_size<scalar_t>(query);
int vec_size = get_vec_size<T>(query);
if ((head_dim / 2) % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
@ -293,7 +305,7 @@ void apply_rotary_embedding_and_cache_copy(
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
const int shared_memory_size = shard_element_num * sizeof(m_scalar_t);
const int shared_memory_size = shard_element_num * sizeof(MT);
switch (vec_size) {
case 1:
@ -313,7 +325,7 @@ void apply_rotary_embedding_and_cache_copy(
AT_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t, bool high_precision>
template<typename T, bool high_precision>
void apply_rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
@ -330,9 +342,9 @@ void apply_rotary_embedding(
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
int vec_size = get_vec_size<scalar_t>(query);
int vec_size = get_vec_size<T>(query);
if ((head_dim / 2) % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
@ -350,11 +362,11 @@ void apply_rotary_embedding(
switch (vec_size) {
case 1:
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
rotary_embedding_kernel<T, MT, 1><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
query.data_ptr<T>(),
key.data_ptr<T>(),
cos.data_ptr<T>(),
sin.data_ptr<T>(),
query_stride,
key_stride,
shard_element_num / 2,
@ -366,11 +378,11 @@ void apply_rotary_embedding(
);
break;
case 2:
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
rotary_embedding_kernel<T, MT, 2><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
query.data_ptr<T>(),
key.data_ptr<T>(),
cos.data_ptr<T>(),
sin.data_ptr<T>(),
query_stride,
key_stride,
shard_element_num / 2,
@ -382,11 +394,11 @@ void apply_rotary_embedding(
);
break;
case 4:
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
rotary_embedding_kernel<T, MT, 4><<<grid, block, shard_element_num * sizeof(MT), stream>>>(
query.data_ptr<T>(),
key.data_ptr<T>(),
cos.data_ptr<T>(),
sin.data_ptr<T>(),
query_stride,
key_stride,
shard_element_num / 2,
@ -416,21 +428,81 @@ void rotary_embedding_and_cache_copy(
at::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding_and_cache_copy",
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
query,
key,
value,
cos,
sin,
key_cache,
value_cache,
sequence_lengths,
block_tables
);)
#define _(T, CacheT, HIGH_PRECISION) \
apply_rotary_embedding_and_cache_copy<T, CacheT, HIGH_PRECISION>( \
query, \
key, \
value, \
cos, \
sin, \
key_cache, \
value_cache, \
sequence_lengths, \
block_tables);
if(key_cache.scalar_type() == at::ScalarType::Byte)
{
if(high_precision) {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, uint8_t, true)
break;
case at::ScalarType::Half:
_(half, uint8_t, true)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, uint8_t, true)
break;
}
}
else {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, uint8_t, false)
break;
case at::ScalarType::Half:
_(half, uint8_t, false)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, uint8_t, false)
break;
}
}
}
else
{
if(high_precision) {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, float, true)
break;
case at::ScalarType::Half:
_(half, half, true)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, __nv_bfloat16, true)
break;
}
}
else {
switch (key.scalar_type())
{
case at::ScalarType::Float:
_(float, float, false)
break;
case at::ScalarType::Half:
_(half, half, false)
break;
case at::ScalarType::BFloat16:
_(__nv_bfloat16, __nv_bfloat16, false)
break;
}
}
}
#undef _
}
void rotary_embedding(

View File

@ -11,6 +11,7 @@ namespace colossalAI {
namespace cuda {
namespace utils {
// Note(LiuYang): Depreciated
template <typename T, int vec_size>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
using VT = typename common::VecTypeTrait<T, vec_size>::Type;
@ -26,6 +27,7 @@ __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
*(reinterpret_cast<const float4 *>(src + 4));
}
// Note(LiuYang): Depreciated
template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) {
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
@ -36,13 +38,12 @@ template <typename SrcT, typename DstT, int vec_size>
__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
*(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
*(reinterpret_cast<const SrcVT *>(src)));
}
template <typename T, int vec_size>
__device__ __inline__ void copy<T, T, vec_size>(const T *src, T *dst) {
__device__ __inline__ void copy(const T *src, T *dst) {
using VT = typename common::VecTypeTrait<T, vec_size>::Type;
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}