From 808ee6e4addccb51990398434547fa5df3c255b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?=
Date: Tue, 30 Apr 2024 11:26:36 +0800
Subject: [PATCH] [Inference/Feat] Feat quant kvcache step2 (#5674)

---
 extensions/csrc/funcs/cast_functor.h          | 120 ++++++++++++++---
 .../cuda/context_kv_cache_memcpy_kernel.cu    | 126 ++++++++++++------
 .../cuda/flash_decoding_attention_kernel.cu   |   2 +-
 extensions/csrc/kernel/cuda/utils/vec_copy.h  |  31 ++++-
 4 files changed, 208 insertions(+), 71 deletions(-)

diff --git a/extensions/csrc/funcs/cast_functor.h b/extensions/csrc/funcs/cast_functor.h
index d33eece59..d9691d870 100644
--- a/extensions/csrc/funcs/cast_functor.h
+++ b/extensions/csrc/funcs/cast_functor.h
@@ -9,6 +9,7 @@
 #endif
 
 #include
+#include <cuda_fp8.h>
 
 #include
 
@@ -175,6 +176,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
                                        return res.x;
                                      }))
 
+// half raw -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
+                                       __half_raw tmp;
+                                       tmp.x = val;
+                                       __nv_fp8_storage_t res =
+                                           __nv_cvt_halfraw_to_fp8(
+                                               tmp, __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint8_t>(res);
+                                     }))
+
 // fp8x2 -> half2 raw
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({
                                        union {
@@ -222,6 +233,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
                                        return half(res);
                                      }))
 
+// half -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, uint8_t, DEVICE, STMTS_WRAPPER({
+                                       __half_raw tmp(val);
+                                       __nv_fp8_storage_t res =
+                                           __nv_cvt_halfraw_to_fp8(
+                                               tmp, __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint8_t>(res);
+                                     }))
+
 // fp8x2 -> half2
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
                                        __half2_raw res =
@@ -230,6 +250,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
                                        return half2(res);
                                      }))
 
+// half2 -> fp8x2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, uint16_t, DEVICE, STMTS_WRAPPER({
+                                       __half2_raw tmp(val);
+                                       __nv_fp8x2_storage_t res =
+                                           __nv_cvt_halfraw2_to_fp8x2(
+                                               tmp, __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint16_t>(res);
+                                     }))
+
 // fp8x4 -> half4
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
@@ -242,6 +271,20 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
     }))
 
+// half4 -> fp8x4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    dtype::half4, uint32_t, DEVICE, STMTS_WRAPPER({
+      half2 x, y;
+      x = val.x;
+      y = val.y;
+      uint16_t lo, hi;
+      lo = CastFunctor<half2, uint16_t>()(x);
+      hi = CastFunctor<half2, uint16_t>()(y);
+      uint32_t res;
+      asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(lo), "h"(hi));
+      return res;
+    }))
+
 // fp8x8 -> half8
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
@@ -314,6 +357,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
     }))
 
+// float -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
+                                       __nv_fp8_storage_t res =
+                                           __nv_cvt_float_to_fp8(
+                                               val, __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint8_t>(res);
+                                     }))
+
 // fp8x2 -> float2
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint16_t, float2, DEVICE, STMTS_WRAPPER({
@@ -328,6 +379,28 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return make_float2(lof, hif);
     }))
 
+// float2 -> fp8x2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    float2, uint16_t, DEVICE, STMTS_WRAPPER({
+      uint16_t tmp1 =
+          static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
+      uint16_t tmp2 =
+          static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
+      uint16_t res = (tmp1 << 8U) | tmp2;
+      return res;
+    }))
+
+// float4 -> fp8x4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
+                                       uint32_t a, b, c, d;
+                                       a = CastFunctor<float, uint8_t>()(val.x);
+                                       b = CastFunctor<float, uint8_t>()(val.y);
+                                       c = CastFunctor<float, uint8_t>()(val.z);
+                                       d = CastFunctor<float, uint8_t>()(val.w);
+                                       return (a << 24U) | (b << 16U) |
+                                              (c << 8U) | d;
+                                     }))
+
 // fp8x4 -> float4_
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
@@ -338,6 +411,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
     }))
 
+// fp8x4 -> float4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, float4, DEVICE, STMTS_WRAPPER({
+      dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
+      float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+      return res;
+    }))
+
 // fp8x8 -> float8_
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
@@ -352,16 +433,6 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
     }))
 
-// half -> fp8
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
-                                       __half_raw tmp;
-                                       tmp.x = val;
-                                       __nv_fp8_storage_t res =
-                                           __nv_cvt_halfraw_to_fp8(
-                                               tmp, __NV_SATFINITE, __NV_E5M2);
-                                       return static_cast<uint8_t>(res);
-                                     }))
-
 // bf16 -> fp8
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
                                      STMTS_WRAPPER({
@@ -376,19 +447,24 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
                                      STMTS_WRAPPER({
 #endif
                                      }))
 
-// float -> fp8
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
-                                       __nv_fp8_storage_t res =
-                                           __nv_cvt_float_to_fp8(
-                                               val, __NV_SATFINITE, __NV_E5M2);
-                                       return static_cast<uint8_t>(res);
-                                     }))
-
-// fp8x4 -> float4
+// bf162 -> fp8x2
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    uint32_t, float4, DEVICE, STMTS_WRAPPER({
-      dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
-      float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+    __nv_bfloat162, uint16_t, DEVICE, STMTS_WRAPPER({
+      uint16_t a =
+          static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
+      uint16_t b =
+          static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
+      return (a << 8U) | b;
+    }))
+
+// bf164 -> fp8x4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    dtype::bfloat164, uint32_t, DEVICE, STMTS_WRAPPER({
+      uint32_t res;
+      uint16_t a, b;
+      a = CastFunctor<__nv_bfloat162, uint16_t>()(val.x);
+      b = CastFunctor<__nv_bfloat162, uint16_t>()(val.y);
+      asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(a), "h"(b));
       return res;
     }))
 
diff --git a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
index 9b3a8261e..6e849b074 100644
--- a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
@@ -4,16 +4,17 @@
 #include "utils/vec_copy.h"
 #include "common/micros.h"
 
-using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
+using colossalAI::cuda::utils::copy;
+using colossalAI::funcs::CastFunctor;
 
-template <typename scalar_t, bool Aligned, int VecSize>
+template <typename T, typename CacheT, bool Aligned, int VecSize>
 __global__ void context_kv_cache_memcpy_kernel(
-    const scalar_t* __restrict__ key,
-    const scalar_t* __restrict__ value,
-    scalar_t* __restrict__ key_cache,
-    scalar_t* __restrict__ value_cache,
+    const T* __restrict__ key,
+    const T* __restrict__ value,
+    CacheT* __restrict__ key_cache,
+    CacheT* __restrict__ value_cache,
     const int* __restrict__ sequence_lengths,
     const int* __restrict__ cu_seqlens,
     const int* __restrict__ block_tables,
@@ -54,8 +55,8 @@ __global__ void context_kv_cache_memcpy_kernel(
                                 + head_id * block_size * head_dim
                                 + block_offset * head_dim + head_offset;
 
-      copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
-      copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
+      copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
+      copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
     }
 
     // tail process
@@ -69,22 +70,22 @@ __global__ void context_kv_cache_memcpy_kernel(
                                 + head_id * block_size * head_dim
                                 + block_offset * head_dim + head_offset;
 
-      key_cache[target_id] = key[key_src_id];
-      value_cache[target_id] = value[value_src_id];
+      key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
+      value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
     }
   }
 }
 
-template <typename scalar_t>
+template <typename T, typename CacheT>
 void apply_context_kv_cache_memcpy(
-    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
-    at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& sequence_lengths,    // [batch_size]
-    at::Tensor& cu_seqlens,          // [batch_size + 1]
-    at::Tensor& block_tables,        // [batch_size, max_seq_len]
+    torch::Tensor& key,              // [num_tokens, head_num, head_dim]
+    torch::Tensor& value,            // [num_tokens, head_num, head_dim]
+    torch::Tensor& key_cache,        // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& value_cache,      // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& sequence_lengths, // [batch_size]
+    torch::Tensor& cu_seqlens,       // [batch_size + 1]
+    torch::Tensor& block_tables,     // [batch_size, max_seq_len]
     int max_seq_len_in_batch)
 {
   int num_tokens = key.size(0);
@@ -97,7 +98,7 @@ void apply_context_kv_cache_memcpy(
   int64_t value_stride = value.stride(0);
   int block_table_stride = block_tables.stride(0);
 
-  int vec_size = get_vec_size<scalar_t>(key);
+  int vec_size = get_vec_size<T>(key);
 
   bool aligned = true;
   if (head_dim % vec_size != 0) {
@@ -112,11 +113,11 @@ void apply_context_kv_cache_memcpy(
 
 #define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                             \
   do {                                                                                            \
-    context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>(  \
-        key.data_ptr<scalar_t>(),                                                                 \
-        value.data_ptr<scalar_t>(),                                                               \
-        key_cache.data_ptr<scalar_t>(),                                                           \
-        value_cache.data_ptr<scalar_t>(),                                                         \
+    context_kv_cache_memcpy_kernel<T, CacheT, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+        reinterpret_cast<T*>(key.data_ptr()),                                                     \
+        reinterpret_cast<T*>(value.data_ptr()),                                                   \
+        reinterpret_cast<CacheT*>(key_cache.data_ptr()),                                          \
+        reinterpret_cast<CacheT*>(value_cache.data_ptr()),                                        \
         sequence_lengths.data_ptr<int>(),                                                         \
         cu_seqlens.data_ptr<int>(),                                                               \
         block_tables.data_ptr<int>(),                                                             \
@@ -161,26 +162,63 @@ void apply_context_kv_cache_memcpy(
 }
 
 void context_kv_cache_memcpy(
-    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
-    at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& sequence_lengths,    // [batch_size]
-    at::Tensor& cu_seqlens,          // [batch_size + 1]
-    at::Tensor& block_tables,        // [batch_size, max_seq_len]
+    torch::Tensor& key,              // [num_tokens, head_num, head_dim]
+    torch::Tensor& value,            // [num_tokens, head_num, head_dim]
+    torch::Tensor& key_cache,        // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& value_cache,      // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& sequence_lengths, // [batch_size]
+    torch::Tensor& cu_seqlens,       // [batch_size + 1]
+    torch::Tensor& block_tables,     // [batch_size, max_seq_len]
     int max_seq_len_in_batch)
 {
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-      key.scalar_type(),
-      "context_kv_cache_memcpy",
-      apply_context_kv_cache_memcpy<scalar_t>(
-          key,
-          value,
-          key_cache,
-          value_cache,
-          sequence_lengths,
-          cu_seqlens,
-          block_tables,
-          max_seq_len_in_batch
-      );)
+
+  TORCH_CHECK(key.scalar_type() == at::ScalarType::Float ||
+              key.scalar_type() == at::ScalarType::Half ||
+              key.scalar_type() == at::ScalarType::BFloat16,
+              "Dtype of key should be float, half or bfloat16!");
+  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte ||
+              key_cache.scalar_type() == key.scalar_type(),
+              "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
+
+
+#define _(T, CacheT)                            \
+  apply_context_kv_cache_memcpy<T, CacheT>(     \
+      key,                                      \
+      value,                                    \
+      key_cache,                                \
+      value_cache,                              \
+      sequence_lengths,                         \
+      cu_seqlens,                               \
+      block_tables,                             \
+      max_seq_len_in_batch                      \
+  )
+
+  if(key_cache.scalar_type() == at::ScalarType::Byte)
+  {
+    switch (key.scalar_type())
+    {
+      case at::ScalarType::Float:
+        _(float, uint8_t);
+        break;
+      case at::ScalarType::Half:
+        _(half, uint8_t);
+        break;
+      case at::ScalarType::BFloat16:
+        _(__nv_bfloat16, uint8_t);
+        break;
+    }
+  }
+  else
+  {
+    switch (key.scalar_type())
+    {
+      case at::ScalarType::Float:
+        _(float, float);
+        break;
+      case at::ScalarType::Half:
+        _(half, half);
+        break;
+      case at::ScalarType::BFloat16:
+        _(__nv_bfloat16, __nv_bfloat16);
+        break;
+    }
+  }
+#undef _
 }
diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
index 9e933ff2a..ac5e40725 100644
--- a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
+++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
@@ -372,7 +372,7 @@ void flash_decoding_attention(
   TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
     "Dtype of query should be float, half or bfloat16!");
 
-  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key_cache.scalar_type(),
+  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
     "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
 
   if(key_cache.scalar_type() == at::ScalarType::Byte)
diff --git a/extensions/csrc/kernel/cuda/utils/vec_copy.h b/extensions/csrc/kernel/cuda/utils/vec_copy.h
index 8fe4e113c..ad98361dd 100644
--- a/extensions/csrc/kernel/cuda/utils/vec_copy.h
+++ b/extensions/csrc/kernel/cuda/utils/vec_copy.h
@@ -11,10 +11,9 @@ namespace colossalAI {
 namespace cuda {
 namespace utils {
 
-template <typename T, int vec_size>
+template <typename T, int VecSize>
 __device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
-  // Note(LiuYang): Here static_cast can't be used for cast between two pointer
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }
 
@@ -33,9 +32,33 @@ __device__ __inline__ void copy_zero_vector(T *dst) {
   *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
 }
 
+template <typename SrcT, typename DstT, int VecSize>
+__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
+  using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
+  using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
+  // Note(LiuYang): Here static_cast can't be used for cast between two pointer
+  *(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
+      *(reinterpret_cast<const SrcVT *>(src)));
+}
+
+template <typename T, int VecSize>
+__device__ __inline__ void copy(const T *src, T *dst) {
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
+  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
+}
+
+template <>
+__device__ __inline__ void copy<float, float, 8>(const float *src, float *dst) {
+  // Since the maximum memory alignment length is 128 bits, we choose float4
+  // here.
+  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
+  *(reinterpret_cast<float4 *>(dst + 4)) =
+      *(reinterpret_cast<const float4 *>(src + 4));
+}
+
 template <typename T>
 int get_vec_size(const torch::Tensor &tensor) {
-  uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr());
   const int max_aligned_size = 128;
   const int dtype_size = sizeof(T) * 8;
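
For reference, here is a minimal standalone sketch (not part of the patch) of the per-element FP8 E5M2 round trip that the new CastFunctor specializations and the copy<SrcT, DstT, VecSize> path perform when the KV cache is stored as uint8_t: the write side matches the scalar tail path of context_kv_cache_memcpy_kernel, the read side matches the existing fp8 -> half/float functors used by the attention kernel. It only uses the <cuda_fp8.h>/<cuda_fp16.h> intrinsics already relied on above (CUDA 11.8+); the kernel and buffer names are illustrative placeholders, not identifiers from the repository.

#include <cuda_fp16.h>
#include <cuda_fp8.h>

__global__ void fp8_e5m2_roundtrip(const float* __restrict__ in,
                                   float* __restrict__ out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;

  // Write side: float -> fp8 (E5M2, saturate to finite), the same conversion
  // CastFunctor<float, uint8_t> performs when the cache dtype is uint8_t.
  __nv_fp8_storage_t q =
      __nv_cvt_float_to_fp8(in[i], __NV_SATFINITE, __NV_E5M2);

  // Read side: fp8 -> half raw -> float, as done by the fp8 -> half/float
  // functors on the attention read path.
  __half_raw h = __nv_cvt_fp8_to_halfraw(q, __NV_E5M2);
  out[i] = __half2float(__half(h));
}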