[Inference/Feat] Add kvcache quantization support for FlashDecoding (#5656)

pull/5674/head
傅剑寒 2024-04-26 19:40:37 +08:00 committed by GitHub
parent 5be590b99e
commit 8ccb6714e7
5 changed files with 482 additions and 174 deletions

View File

@@ -5,6 +5,7 @@
#include <cuda_fp16.h>
#endif
#include <ATen/ATen.h>
#include <stdint.h>
#include "common/data_type.h"
@@ -27,6 +28,7 @@ struct FloatVecTypeTrait {};
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
#if defined(COLOSSAL_WITH_CUDA)
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
@@ -35,18 +37,19 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, uint16_t)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, uint32_t)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, uint2)
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164);
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
#endif /* defined(COLOSSAL_WITH_CUDA) */
#undef VEC_TYPE_TRAITS_SPECIALIZATION
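Note: the reworked uint8_t rows are the core of this hunk. fp8 cache values are stored as raw bytes, so a vector of N fp8 values now travels as an N-byte integer word and is only decoded after the load. A minimal sketch of the trait mechanism under that reading (simplified names, not the actual header):

#include <cstdint>
#include <cuda_runtime.h>  // vector types (uint2, ...)
#include <cuda_fp16.h>

// Sketch: map (scalar type, vector width) -> packed type used for one load.
// For fp8 stored as uint8_t, the packed type is a plain integer word of
// matching byte width; the bytes are decoded only after the load.
template <typename T, int VecSize> struct VecTypeTraitSketch;

template <> struct VecTypeTraitSketch<half, 2>    { using Type = half2;    };
template <> struct VecTypeTraitSketch<uint8_t, 2> { using Type = uint16_t; };  // 2 fp8 bytes
template <> struct VecTypeTraitSketch<uint8_t, 4> { using Type = uint32_t; };  // 4 fp8 bytes
template <> struct VecTypeTraitSketch<uint8_t, 8> { using Type = uint2;    };  // 8 fp8 bytes

static_assert(sizeof(VecTypeTraitSketch<uint8_t, 8>::Type) == 8,
              "eight fp8 bytes travel as one 64-bit load");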

View File

@@ -4,9 +4,12 @@
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#endif
#include <assert.h>
#include <functional>
#include "common/data_type.h"
@@ -23,141 +26,421 @@ struct CastFunctor : public std::unary_function<From, To> {
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
};
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMTS, \
FUNCTION_MODIFIER) \
template <> \
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
#define STMTS_WRAPPER(...) __VA_ARGS__
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, FUNCTION_MODIFIER, \
STMTS) \
template <> \
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
};
#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float, float2, { return make_float2(val, val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, DEVICE, STMTS_WRAPPER({
return make_float2(val.x, val.y);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, DEVICE, STMTS_WRAPPER({
return make_float2(val, val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
half2, float2, { return __half22float2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float2, half2, { return __float22half2_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float, half, { return __float2half_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float, half2, { return __float2half2_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
half, half2, { return __half2half2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
half, float, { return __half2float(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float4, dtype::half4,
{
dtype::half4 dst;
dst.x = __floats2half2_rn(val.x, val.y);
dst.y = __floats2half2_rn(val.z, val.w);
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float4_, dtype::half4,
{
dtype::half4 dst;
dst.x = __float22half2_rn(val.x);
dst.y = __float22half2_rn(val.y);
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, dtype::half8,
{
dtype::half8 dst;
dst.x = __float22half2_rn(val.x);
dst.y = __float22half2_rn(val.y);
dst.z = __float22half2_rn(val.z);
dst.w = __float22half2_rn(val.w);
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, DEVICE, STMTS_WRAPPER({
return __half22float2(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, DEVICE, STMTS_WRAPPER({
return __float22half2_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, DEVICE, STMTS_WRAPPER({
return __float2half_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, DEVICE, STMTS_WRAPPER({
return __float2half2_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, DEVICE, STMTS_WRAPPER({
return __half2half2(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, DEVICE, STMTS_WRAPPER({
return __half2float(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
STMTS_WRAPPER({
dtype::half4 dst;
dst.x = __floats2half2_rn(val.x, val.y);
dst.y = __floats2half2_rn(val.z, val.w);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
STMTS_WRAPPER({
dtype::half4 dst;
dst.x = __float22half2_rn(val.x);
dst.y = __float22half2_rn(val.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
STMTS_WRAPPER({
dtype::half8 dst;
dst.x = __float22half2_rn(val.x);
dst.y = __float22half2_rn(val.y);
dst.z = __float22half2_rn(val.z);
dst.w = __float22half2_rn(val.w);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float, __nv_bfloat162, { return __float2bfloat162_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float4, dtype::bfloat164,
{
dtype::bfloat164 dst;
dst.x = __floats2bfloat162_rn(val.x, val.y);
dst.y = __floats2bfloat162_rn(val.z, val.w);
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
return __float2bfloat162_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
STMTS_WRAPPER({
return __float2bfloat16_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x =
__floats2bfloat162_rn(val.x, val.y);
dst.y =
__floats2bfloat162_rn(val.z, val.w);
return dst;
}))
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, __nv_bfloat162, { return __bfloat162bfloat162(val); },
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, float2, { return __bfloat1622float2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float4_, dtype::bfloat164,
{
dtype::bfloat164 dst;
dst.x = __float22bfloat162_rn(val.x);
dst.y = __float22bfloat162_rn(val.y);
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, dtype::bfloat168,
{
dtype::bfloat168 dst;
dst.x = __float22bfloat162_rn(val.x);
dst.y = __float22bfloat162_rn(val.y);
dst.z = __float22bfloat162_rn(val.z);
dst.w = __float22bfloat162_rn(val.w);
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
return __bfloat162bfloat162(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE,
STMTS_WRAPPER({
return __bfloat1622float2(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
return __float22bfloat162_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x = __float22bfloat162_rn(val.x);
dst.y = __float22bfloat162_rn(val.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
STMTS_WRAPPER({
dtype::bfloat168 dst;
dst.x = __float22bfloat162_rn(val.x);
dst.y = __float22bfloat162_rn(val.y);
dst.z = __float22bfloat162_rn(val.z);
dst.w = __float22bfloat162_rn(val.w);
return dst;
}))
#else
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
__nv_bfloat162 dst;
dst.x = val;
dst.y = val;
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE,
STMTS_WRAPPER({
return make_float2(__low2float(val),
__high2float(val));
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
return __floats2bfloat162_rn(val.x,
val.y);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, __nv_bfloat162,
{
__nv_bfloat162 dst;
dst.x = val;
dst.y = val;
return dst;
},
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, float2,
{ return make_float2(__low2float(val), __high2float(val)); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float4_, dtype::bfloat164,
{
dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
return dst;
},
DEVICE)
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, dtype::bfloat168,
{
dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
dtype::bfloat168 dst;
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
dst.w = __floats2bfloat162_rn(val.w.x, val.w.y);
return dst;
},
DEVICE)
}))
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */
// quant utils
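Note: all conversions added below fix the E5M2 interpretation: 1 sign, 5 exponent, and 2 mantissa bits. E5M2 shares the half exponent range, so the half path is a plain mantissa change while bf16 and float detour through half. Illustrative layout (not code from the patch):

//   half : s eeeee mmmmmmmmmm   (1 + 5 + 10 bits)
//   E5M2 : s eeeee mm           (1 + 5 + 2 bits)
// Equal exponent width -> fp8 <-> half is mantissa truncation/extension.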
// fp8 -> half raw
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
__half_raw res = __nv_cvt_fp8_to_halfraw(
val, __NV_E5M2);
return res.x;
}))
// fp8x2 -> half2 raw
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res =
__nv_cvt_fp8x2_to_halfraw2(
val, __NV_E5M2);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}))
// fp8x4 -> half2x2 raw
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, uint2, DEVICE, STMTS_WRAPPER({
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] =
CastFunctor<uint16_t, uint32_t>()(static_cast<uint16_t>(val));
tmp.u32[1] =
CastFunctor<uint16_t, uint32_t>()(static_cast<uint16_t>(val >> 16U));
return tmp.u32x2;
}))
// fp8x8 -> half2x4 raw
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, uint4, DEVICE, STMTS_WRAPPER({
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = CastFunctor<uint32_t, uint2>()(val.x);
tmp.u64[1] = CastFunctor<uint32_t, uint2>()(val.y);
return tmp.u64x2;
}))
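Note the composition pattern in the raw conversions above: only the 1- and 2-wide functors call the __nv_cvt_* intrinsics directly; the 4- and 8-wide ones split their input into 16-bit pairs and reuse the narrower functor. A standalone sketch of the fp8x4 path (hypothetical helper, same E5M2 interpretation):

#include <cstdint>
#include <cuda_fp8.h>

// Split the 32-bit word into two fp8x2 pairs and convert each pair with the
// same intrinsic the fp8x2 functor uses; repack the resulting half bits.
__device__ uint2 fp8x4_to_half2x2(uint32_t v) {
  __half2_raw lo = __nv_cvt_fp8x2_to_halfraw2(
      static_cast<__nv_fp8x2_storage_t>(v), __NV_E5M2);
  __half2_raw hi = __nv_cvt_fp8x2_to_halfraw2(
      static_cast<__nv_fp8x2_storage_t>(v >> 16), __NV_E5M2);
  uint2 out;
  out.x = static_cast<uint32_t>(lo.x) | (static_cast<uint32_t>(lo.y) << 16);
  out.y = static_cast<uint32_t>(hi.x) | (static_cast<uint32_t>(hi.y) << 16);
  return out;
}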
// fp8 -> half
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
__half_raw res = __nv_cvt_fp8_to_halfraw(
val, __NV_E5M2);
return half(res);
}))
// fp8x2 -> half2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
__half2_raw res =
__nv_cvt_fp8x2_to_halfraw2(
val, __NV_E5M2);
return half2(res);
}))
// fp8x4 -> half4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
half2 tmp1, tmp2;
tmp1 = CastFunctor<uint16_t, half2>()(static_cast<uint16_t>(val));
tmp2 = CastFunctor<uint16_t, half2>()(static_cast<uint16_t>(val >> 16U));
dtype::half4 res;
res.x = tmp1;
res.y = tmp2;
return res;
}))
// fp8x8 -> half8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
dtype::half4 tmp1, tmp2;
tmp1 = CastFunctor<uint32_t, dtype::half4>()(val.x);
tmp2 = CastFunctor<uint32_t, dtype::half4>()(val.y);
dtype::half8 res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}))
// fp8 -> __nv_bfloat16
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint8_t, __nv_bfloat16, DEVICE, STMTS_WRAPPER({
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(val, __NV_E5M2);
// half -> float -> bf16
float tmp;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(tmp) : "h"(res.x));
return __float2bfloat16(tmp);
}))
// fp8x2 -> __nv_bfloat162
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint16_t, __nv_bfloat162, DEVICE, STMTS_WRAPPER({
__nv_bfloat162 res;
res.x = CastFunctor<uint8_t, __nv_bfloat16>()(static_cast<uint8_t>(val));
res.y = CastFunctor<uint8_t, __nv_bfloat16>()(
static_cast<uint8_t>(val >> 8U));
return res;
}))
// fp8x4 -> bfloat164
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
dtype::bfloat164 res;
res.x =
CastFunctor<uint16_t, __nv_bfloat162>()(static_cast<uint16_t>(val));
res.y = CastFunctor<uint16_t, __nv_bfloat162>()(
static_cast<uint16_t>(val >> 16U));
return res;
}))
// fp8x8 -> bfloat168
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
dtype::bfloat164 tmp1, tmp2;
tmp1 = CastFunctor<uint32_t, dtype::bfloat164>()(val.x);
tmp2 = CastFunctor<uint32_t, dtype::bfloat164>()(val.y);
dtype::bfloat168 res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}))
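As the comment in the fp8 -> __nv_bfloat16 functor notes, there is no direct fp8-to-bf16 intrinsic, so the conversion goes fp8 -> half -> float (via inline cvt.f32.f16 PTX) -> bf16. A sketch of the same conversion using standard intrinsics instead of PTX (hypothetical helper):

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

__device__ __nv_bfloat16 fp8_to_bf16(uint8_t v) {
  __half_raw h = __nv_cvt_fp8_to_halfraw(v, __NV_E5M2);  // fp8 -> half bits
  float f = __half2float(__half(h));                     // half -> float
  return __float2bfloat16(f);                            // float -> bf16 (rn)
}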
// fp8 -> float
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint8_t, float, DEVICE, STMTS_WRAPPER({
// fp8 -> half
uint16_t tmp = CastFunctor<uint8_t, uint16_t>()(val);
// half -> float
float res;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(res) : "h"(tmp));
return res;
}))
// fp8x2 -> float2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint16_t, float2, DEVICE, STMTS_WRAPPER({
// fp8x2 -> half2
uint32_t tmp = CastFunctor<uint16_t, uint32_t>()(val);
// half2 -> float2
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(tmp));
float lof, hif;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(lof) : "h"(lo));
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(hif) : "h"(hi));
return make_float2(lof, hif);
}))
// fp8x4 -> float4_
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
dtype::float4_ res;
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
res.y =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
return res;
}))
// fp8x8 -> float8_
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp1, tmp2;
tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
dtype::float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}))
// half -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
__half_raw tmp;
tmp.x = val;
__nv_fp8_storage_t res =
__nv_cvt_halfraw_to_fp8(
tmp, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
// bf16 -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
STMTS_WRAPPER({
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res =
__nv_cvt_bfloat16raw_to_fp8(
__nv_bfloat16_raw(val),
__NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
#endif
}))
// float -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(
val, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
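The float -> uint8_t functor above, together with the earlier uint8_t -> float one, forms the round trip a quantized KV cache relies on. Spelled out as hypothetical free functions (the kernels invoke the same conversions through CastFunctor):

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Store one float as an E5M2 byte, saturating to the finite range.
__device__ uint8_t quantize_e5m2(float v) {
  return static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(v, __NV_SATFINITE, __NV_E5M2));
}

// Load one E5M2 byte back as float, widening through half.
__device__ float dequantize_e5m2(uint8_t b) {
  __half_raw h = __nv_cvt_fp8_to_halfraw(b, __NV_E5M2);
  return __half2float(__half(h));
}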
// fp8x4 -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, float4, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(val);
return uint32;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
STMTS_WRAPPER({
uint2 b;
float2 c;
c.x = val.x.x;
c.y = val.x.y;
b.x = CastFunctor<float2, uint32_t>()(c);
c.x = val.y.x;
c.y = val.y.y;
b.y = CastFunctor<float2, uint32_t>()(c);
return b;
}))
// float4_ -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
STMTS_WRAPPER({
float4 b;
b.x = val.x.x;
b.y = val.x.y;
b.z = val.y.x;
b.w = val.y.y;
return b;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
uint4 b;
b.x = CastFunctor<float2, uint32_t>()(val.x);
b.y = CastFunctor<float2, uint32_t>()(val.y);
b.z = CastFunctor<float2, uint32_t>()(val.z);
b.w = CastFunctor<float2, uint32_t>()(val.w);
return b;
}))
#endif /* defined(COLOSSAL_WITH_CUDA) */
#undef STMTS_WRAPPER
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
} // namespace funcs
} // namespace colossalAI
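The benefit of funneling every conversion through CastFunctor<From, To> is that a kernel can be written once over a compute type and a cache type. A hypothetical call site (load_dequant is not part of the patch):

// Read one packed vector from the possibly-quantized cache and widen it to
// the compute type; with cache_t == scalar_t this degenerates to a copy.
template <typename ComputeVecT, typename CacheVecT>
__device__ __forceinline__ ComputeVecT load_dequant(const CacheVecT* ptr) {
  return colossalAI::funcs::CastFunctor<CacheVecT, ComputeVecT>()(*ptr);
}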

View File

@@ -15,21 +15,6 @@
namespace colossalAI {
namespace funcs {
template <typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ii++) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
// Note(LiuYang): A lookup table recording which operations are already
// supported
enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum };

View File

@@ -174,13 +174,13 @@ void context_kv_cache_memcpy(
key.scalar_type(),
"context_kv_cache_memcpy",
apply_context_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
cu_seqlens,
block_tables,
max_seq_len_in_batch
);)
key,
value,
key_cache,
value_cache,
sequence_lengths,
cu_seqlens,
block_tables,
max_seq_len_in_batch
);)
}

View File

@@ -5,7 +5,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include "common/micros.h"
#include "funcs/cast_functor.h"
@@ -34,11 +33,25 @@ constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
return v;
}
template <typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ii++) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
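zero() moves from the shared funcs header into this kernel file, its only user. It clears an accumulator by 32-bit words, which assumes sizeof(T) is a multiple of 4. A usage sketch, as applied to the FMA accumulators below:

//   float4 acc;   // 16 bytes -> four 32-bit words
//   zero(acc);    // acc = {0, 0, 0, 0}, same as zero(accs[i]) further down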
using colossalAI::funcs::BinaryOpType;
using colossalAI::funcs::CastFunctor;
using colossalAI::funcs::TernaryOpFunctor;
using colossalAI::funcs::TernaryOpType;
using colossalAI::funcs::zero;
using colossalAI::common::VecTypeTrait;
using colossalAI::common::FloatVecTypeTrait;
using namespace colossalAI::cuda::attention;
@@ -84,10 +97,12 @@ __global__ void flash_decoding_attention_kernel(
constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using L_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
using KVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using VVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using KQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
using VQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
using LVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
const int context_len = context_lens[seq_idx];
const int thread_group_offset = lane % NUM_THREADS_PER_X;
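The renamed aliases split the vector types into a compute family (KVecT, VVecT, LVecT over scalar_t) and a cache family (KQuantVecT, VQuantVecT over cache_t); CastFunctor bridges the two at each load. An illustrative instantiation (types taken from the trait table in the first file):

// scalar_t = half, cache_t = uint8_t, VEC_SIZE = 4:
//   KVecT      = VecTypeTrait<half,    4>::Type   -> dtype::half4 (8 bytes)
//   KQuantVecT = VecTypeTrait<uint8_t, 4>::Type   -> uint32_t     (4 bytes)
// One 32-bit cache load expands to four halves at the point of use.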
@@ -119,18 +134,18 @@ __global__ void flash_decoding_attention_kernel(
scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
// each warp access a whole block
K_vec q_vecs[NUM_VECS_PER_THREAD];
KVecT q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
const int offset1 = idx % NUM_THREADS_PER_X;
q_vecs[i] = *reinterpret_cast<K_vec*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
q_vecs[i] = *reinterpret_cast<KVecT*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
}
for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
K_vec k_vecs[NUM_VECS_PER_THREAD];
KVecT k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
@@ -142,7 +157,7 @@ __global__ void flash_decoding_attention_kernel(
const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
const int offset2 = idx % NUM_THREADS_PER_X;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE);
k_vecs[j] = CastFunctor<KQuantVecT, KVecT>()(*reinterpret_cast<const KQuantVecT*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE));
}
float qk = scale * Qk_dot<scalar_t, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);
@@ -174,13 +189,13 @@ __global__ void flash_decoding_attention_kernel(
}
__syncthreads();
Float_vec accs[NUM_ROUNDS_PER_TOKEN];
FloatVecT accs[NUM_ROUNDS_PER_TOKEN];
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
zero(accs[i]);
}
V_vec zero_value;
VVecT zero_value;
zero(zero_value);
for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
@@ -193,11 +208,11 @@ __global__ void flash_decoding_attention_kernel(
+ kv_head_idx * kv_head_stride
+ idx * VEC_SIZE;
V_vec v_vecs[NUM_ROUNDS_PER_TOKEN];
VVecT v_vecs[NUM_ROUNDS_PER_TOKEN];
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
v_vecs[i] = (reinterpret_cast<const V_vec*>(v_ptr))[i * WARP_SIZE];
v_vecs[i] = CastFunctor<VQuantVecT, VVecT>()(*((reinterpret_cast<const VQuantVecT*>(v_ptr) + i * WARP_SIZE)));
}
if (token_idx >= context_len) {
@@ -210,7 +225,7 @@ __global__ void flash_decoding_attention_kernel(
logit = CastFunctor<float, scalar_t>()(logits[token_idx]);
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
accs[i] = TernaryOpFunctor<scalar_t, V_vec, Float_vec, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
accs[i] = TernaryOpFunctor<scalar_t, VVecT, FloatVecT, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
}
}
}
@@ -220,16 +235,16 @@ __global__ void flash_decoding_attention_kernel(
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
block_sum<FloatVecT, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
}
scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE;
L_vec out_reg;
LVecT out_reg;
#pragma unroll
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
if (thread_idx < NUM_THREADS_PER_TOKEN) {
out_reg = CastFunctor<Float_vec, L_vec>()(accs[i]);
(reinterpret_cast<L_vec*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
out_reg = CastFunctor<FloatVecT, LVecT>()(accs[i]);
(reinterpret_cast<LVecT*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
}
}
}
@@ -353,18 +368,40 @@ void flash_decoding_attention(
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
float scale) {
switch (query.scalar_type()) {
case at::ScalarType::Float:
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
break;
case at::ScalarType::Half:
CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
break;
case at::ScalarType::BFloat16:
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
break;
default:
AT_ERROR("Unsupported data type: ", toString(query.scalar_type()));
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
"Dtype of query should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
if (key_cache.scalar_type() == at::ScalarType::Byte) {
switch (query.scalar_type()) {
case at::ScalarType::Float:
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t);
break;
case at::ScalarType::Half:
CALL_V1_LAUNCHER_BLOCK_SIZE(half, uint8_t);
break;
case at::ScalarType::BFloat16:
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t);
break;
}
} else {
switch (query.scalar_type()) {
case at::ScalarType::Float:
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
break;
case at::ScalarType::Half:
CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
break;
case at::ScalarType::BFloat16:
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
break;
}
}
}
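Caller-side, opting into quantization only means allocating the caches as Byte tensors; the dispatch above then instantiates the kernels with cache_t = uint8_t and dequantizes cache loads on the fly. A hypothetical libtorch sketch (placeholder sizes and shapes, not values from the patch):

#include <torch/torch.h>

const int64_t num_blocks = 128, num_kv_heads = 8, block_size = 16,
              head_size = 128, x = 8;  // placeholder dimensions
auto opts = torch::dtype(torch::kUInt8).device(torch::kCUDA);
// fp8 KV cache: same layout as the fp16 cache, dtype Byte (uint8).
torch::Tensor key_cache =
    torch::empty({num_blocks, num_kv_heads, head_size / x, block_size, x}, opts);
torch::Tensor value_cache =
    torch::empty({num_blocks, num_kv_heads, block_size, head_size}, opts);
// query/key/value stay float, half, or bfloat16 as before.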