mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Feat] Add quant kvcache support for decode_kv_cache_memcpy (#5686)
parent
db7b3051f4
commit
1ace1065e6
|
@ -2,17 +2,21 @@
|
|||
#include <torch/extension.h>
|
||||
|
||||
#include "utils/vec_copy.h"
|
||||
#include "funcs/cast_functor.h"
|
||||
#include "common/micros.h"
|
||||
|
||||
using colossalAI::cuda::utils::copy_vector;
|
||||
using colossalAI::cuda::utils::get_vec_size;
|
||||
using colossalAI::cuda::utils::copy;
|
||||
using colossalAI::funcs::CastFunctor;
|
||||
|
||||
template<typename scalar_t, bool Aligned, int VecSize>
|
||||
|
||||
template<typename T, typename CacheT, bool Aligned, int VecSize>
|
||||
__global__ void decode_kv_cache_memcpy_kernel(
|
||||
const scalar_t* __restrict__ key,
|
||||
const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache,
|
||||
scalar_t* __restrict__ value_cache,
|
||||
const T* __restrict__ key,
|
||||
const T* __restrict__ value,
|
||||
CacheT* __restrict__ key_cache,
|
||||
CacheT* __restrict__ value_cache,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables,
|
||||
const int head_num,
|
||||
|
@ -52,8 +56,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
|
||||
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
|
||||
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
|
||||
}
|
||||
|
||||
if (!Aligned) {
|
||||
|
@ -73,14 +77,14 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_key_id] = key[key_src_id];
|
||||
value_cache[target_value_id] = value[value_src_id];
|
||||
key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
|
||||
value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename T, typename CacheT>
|
||||
void apply_decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
|
@ -99,7 +103,7 @@ void apply_decode_kv_cache_memcpy(
|
|||
int64_t value_stride = value.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(key);
|
||||
int vec_size = get_vec_size<T>(key);
|
||||
|
||||
bool aligned = true;
|
||||
if (head_dim % vec_size != 0) {
|
||||
|
@ -114,11 +118,11 @@ void apply_decode_kv_cache_memcpy(
|
|||
|
||||
#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
|
||||
do { \
|
||||
decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
|
||||
key.data_ptr<scalar_t>(), \
|
||||
value.data_ptr<scalar_t>(), \
|
||||
key_cache.data_ptr<scalar_t>(), \
|
||||
value_cache.data_ptr<scalar_t>(), \
|
||||
decode_kv_cache_memcpy_kernel<T, CacheT, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<T*>(key.data_ptr()), \
|
||||
reinterpret_cast<T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CacheT*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CacheT*>(value_cache.data_ptr()), \
|
||||
sequence_lengths.data_ptr<int>(), \
|
||||
block_tables.data_ptr<int>(), \
|
||||
head_num, \
|
||||
|
@ -168,15 +172,46 @@ void decode_kv_cache_memcpy(
|
|||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
key.scalar_type(),
|
||||
"decode_kv_cache_memcpy",
|
||||
apply_decode_kv_cache_memcpy<scalar_t>(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
block_tables
|
||||
);)
|
||||
|
||||
#define _(T, CacheT) \
|
||||
apply_decode_kv_cache_memcpy<T, CacheT>( \
|
||||
key, \
|
||||
value, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
sequence_lengths, \
|
||||
block_tables \
|
||||
)
|
||||
|
||||
if(key_cache.scalar_type() == at::ScalarType::Byte)
|
||||
{
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, uint8_t);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, uint8_t);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, uint8_t);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (key.scalar_type())
|
||||
{
|
||||
case at::ScalarType::Float:
|
||||
_(float, float);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
_(half, half);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
_(__nv_bfloat16, __nv_bfloat16);
|
||||
break;
|
||||
}
|
||||
}
|
||||
#undef _
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue