mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Feat] Add kvcache quantization support for FlashDecoding (#5656)
parent
5be590b99e
commit
8ccb6714e7
|
@ -5,6 +5,7 @@
|
|||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "common/data_type.h"
|
||||
|
@ -27,6 +28,7 @@ struct FloatVecTypeTrait {};
|
|||
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
|
||||
|
||||
#if defined(COLOSSAL_WITH_CUDA)
|
||||
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
|
||||
|
@ -35,18 +37,19 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
|
|||
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
|
||||
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, uint16_t)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, uint32_t)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, uint2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
|
||||
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||
|
||||
#undef VEC_TYPE_TRAITS_SPECIALIZATION
|
||||
|
|
|
@ -4,9 +4,12 @@
|
|||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "common/data_type.h"
|
||||
|
@ -23,141 +26,421 @@ struct CastFunctor : public std::unary_function<From, To> {
|
|||
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
|
||||
};
|
||||
|
||||
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMTS, \
|
||||
FUNCTION_MODIFIER) \
|
||||
template <> \
|
||||
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
|
||||
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
|
||||
#define STMTS_WRAPPER(...) __VA_ARGS__
|
||||
|
||||
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, FUNCTION_MODIFIER, \
|
||||
STMTS) \
|
||||
template <> \
|
||||
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
|
||||
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
|
||||
};
|
||||
|
||||
#if defined(COLOSSAL_WITH_CUDA)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float, float2, { return make_float2(val, val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, DEVICE, STMTS_WRAPPER({
|
||||
return make_float2(val.x, val.y);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, DEVICE, STMTS_WRAPPER({
|
||||
return make_float2(val, val);
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
half2, float2, { return __half22float2(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float2, half2, { return __float22half2_rn(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float, half, { return __float2half_rn(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float, half2, { return __float2half2_rn(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
half, half2, { return __half2half2(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
half, float, { return __half2float(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float4, dtype::half4,
|
||||
{
|
||||
dtype::half4 dst;
|
||||
dst.x = __floats2half2_rn(val.x, val.y);
|
||||
dst.y = __floats2half2_rn(val.z, val.w);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float4_, dtype::half4,
|
||||
{
|
||||
dtype::half4 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
dst.y = __float22half2_rn(val.y);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, dtype::half8,
|
||||
{
|
||||
dtype::half8 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
dst.y = __float22half2_rn(val.y);
|
||||
dst.z = __float22half2_rn(val.z);
|
||||
dst.w = __float22half2_rn(val.w);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, DEVICE, STMTS_WRAPPER({
|
||||
return __half22float2(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, DEVICE, STMTS_WRAPPER({
|
||||
return __float22half2_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, DEVICE, STMTS_WRAPPER({
|
||||
return __float2half_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, DEVICE, STMTS_WRAPPER({
|
||||
return __float2half2_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, DEVICE, STMTS_WRAPPER({
|
||||
return __half2half2(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, DEVICE, STMTS_WRAPPER({
|
||||
return __half2float(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::half4 dst;
|
||||
dst.x = __floats2half2_rn(val.x, val.y);
|
||||
dst.y = __floats2half2_rn(val.z, val.w);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::half4 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
dst.y = __float22half2_rn(val.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::half8 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
dst.y = __float22half2_rn(val.y);
|
||||
dst.z = __float22half2_rn(val.z);
|
||||
dst.w = __float22half2_rn(val.w);
|
||||
return dst;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float, __nv_bfloat162, { return __float2bfloat162_rn(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float4, dtype::bfloat164,
|
||||
{
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x, val.y);
|
||||
dst.y = __floats2bfloat162_rn(val.z, val.w);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float2bfloat162_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float2bfloat16_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x =
|
||||
__floats2bfloat162_rn(val.x, val.y);
|
||||
dst.y =
|
||||
__floats2bfloat162_rn(val.z, val.w);
|
||||
return dst;
|
||||
}))
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, __nv_bfloat162, { return __bfloat162bfloat162(val); },
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat162, float2, { return __bfloat1622float2(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float4_, dtype::bfloat164,
|
||||
{
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
dst.y = __float22bfloat162_rn(val.y);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, dtype::bfloat168,
|
||||
{
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
dst.y = __float22bfloat162_rn(val.y);
|
||||
dst.z = __float22bfloat162_rn(val.z);
|
||||
dst.w = __float22bfloat162_rn(val.w);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __bfloat162bfloat162(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __bfloat1622float2(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float22bfloat162_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
dst.y = __float22bfloat162_rn(val.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
dst.y = __float22bfloat162_rn(val.y);
|
||||
dst.z = __float22bfloat162_rn(val.z);
|
||||
dst.w = __float22bfloat162_rn(val.w);
|
||||
return dst;
|
||||
}))
|
||||
#else
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
__nv_bfloat162 dst;
|
||||
dst.x = val;
|
||||
dst.y = val;
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return make_float2(__low2float(val),
|
||||
__high2float(val));
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __floats2bfloat162_rn(val.x,
|
||||
val.y);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, __nv_bfloat162,
|
||||
{
|
||||
__nv_bfloat162 dst;
|
||||
dst.x = val;
|
||||
dst.y = val;
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat162, float2,
|
||||
{ return make_float2(__low2float(val), __high2float(val)); }, DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float4_, dtype::bfloat164,
|
||||
{
|
||||
dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, dtype::bfloat168,
|
||||
{
|
||||
dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||
dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
|
||||
dst.w = __floats2bfloat162_rn(val.w.x, val.w.y);
|
||||
return dst;
|
||||
},
|
||||
DEVICE)
|
||||
}))
|
||||
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */
|
||||
|
||||
// quant utils
|
||||
// fp8 -> half raw
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(
|
||||
val, __NV_E5M2);
|
||||
return res.x;
|
||||
}))
|
||||
|
||||
// fp8x2 -> half2 raw
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res =
|
||||
__nv_cvt_fp8x2_to_halfraw2(
|
||||
val, __NV_E5M2);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}))
|
||||
|
||||
// fp8x4 -> half2x2 raw
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, uint2, DEVICE, STMTS_WRAPPER({
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] =
|
||||
CastFunctor<uint16_t, uint32_t>()(static_cast<uint16_t>(val));
|
||||
tmp.u32[1] =
|
||||
CastFunctor<uint16_t, uint32_t>()(static_cast<uint16_t>(val >> 16U));
|
||||
return tmp.u32x2;
|
||||
}))
|
||||
|
||||
// fp8x8 -> half2x4 raw
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, uint4, DEVICE, STMTS_WRAPPER({
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = CastFunctor<uint32_t, uint2>()(val.x);
|
||||
tmp.u64[1] = CastFunctor<uint32_t, uint2>()(val.y);
|
||||
return tmp.u64x2;
|
||||
}))
|
||||
|
||||
// fp8 -> half
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(
|
||||
val, __NV_E5M2);
|
||||
return half(res);
|
||||
}))
|
||||
|
||||
// fp8x2 -> half2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
|
||||
__half2_raw res =
|
||||
__nv_cvt_fp8x2_to_halfraw2(
|
||||
val, __NV_E5M2);
|
||||
return half2(res);
|
||||
}))
|
||||
|
||||
// fp8x4 -> half4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
|
||||
half2 tmp1, tmp2;
|
||||
tmp1 = CastFunctor<uint16_t, half2>()(static_cast<uint16_t>(val));
|
||||
tmp2 = CastFunctor<uint16_t, half2>()(static_cast<uint16_t>(val >> 16U));
|
||||
dtype::half4 res;
|
||||
res.x = tmp1;
|
||||
res.y = tmp2;
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x8 -> half8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
|
||||
dtype::half4 tmp1, tmp2;
|
||||
tmp1 = CastFunctor<uint32_t, dtype::half4>()(val.x);
|
||||
tmp2 = CastFunctor<uint32_t, dtype::half4>()(val.y);
|
||||
dtype::half8 res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint8_t, __nv_bfloat16, DEVICE, STMTS_WRAPPER({
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(val, __NV_E5M2);
|
||||
// half -> float -> bf16
|
||||
float tmp;
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(tmp) : "h"(res.x));
|
||||
return __float2bfloat16(tmp);
|
||||
}))
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint16_t, __nv_bfloat162, DEVICE, STMTS_WRAPPER({
|
||||
__nv_bfloat162 res;
|
||||
res.x = CastFunctor<uint8_t, __nv_bfloat16>()(static_cast<uint8_t>(val));
|
||||
res.y = CastFunctor<uint8_t, __nv_bfloat16>()(
|
||||
static_cast<uint8_t>(val >> 8U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x4 -> bfloat164
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat164 res;
|
||||
res.x =
|
||||
CastFunctor<uint16_t, __nv_bfloat162>()(static_cast<uint16_t>(val));
|
||||
res.y = CastFunctor<uint16_t, __nv_bfloat162>()(
|
||||
static_cast<uint16_t>(val >> 16U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x8 -> bfloat168
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat164 tmp1, tmp2;
|
||||
tmp1 = CastFunctor<uint32_t, dtype::bfloat164>()(val.x);
|
||||
tmp2 = CastFunctor<uint32_t, dtype::bfloat164>()(val.y);
|
||||
dtype::bfloat168 res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8 -> float
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint8_t, float, DEVICE, STMTS_WRAPPER({
|
||||
// fp8 -> half
|
||||
uint16_t tmp = CastFunctor<uint8_t, uint16_t>()(val);
|
||||
// half -> float
|
||||
float res;
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(res) : "h"(tmp));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x2 -> float2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint16_t, float2, DEVICE, STMTS_WRAPPER({
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = CastFunctor<uint16_t, uint32_t>()(val);
|
||||
// half2 -> float2
|
||||
uint16_t lo, hi;
|
||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(tmp));
|
||||
float lof, hif;
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(lof) : "h"(lo));
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(hif) : "h"(hi));
|
||||
return make_float2(lof, hif);
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4_
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ res;
|
||||
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
|
||||
res.y =
|
||||
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x8 -> float8_
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ tmp1, tmp2;
|
||||
tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
|
||||
tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
|
||||
dtype::float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}))
|
||||
|
||||
// half -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__half_raw tmp;
|
||||
tmp.x = val;
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(
|
||||
tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// bf16 -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_bfloat16raw_to_fp8(
|
||||
__nv_bfloat16_raw(val),
|
||||
__NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
#endif
|
||||
}))
|
||||
|
||||
// float -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(
|
||||
val, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, float4, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(val);
|
||||
return uint32;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
uint2 b;
|
||||
float2 c;
|
||||
c.x = val.x.x;
|
||||
c.y = val.x.y;
|
||||
b.x = CastFunctor<float2, uint32_t>()(c);
|
||||
|
||||
c.x = val.y.x;
|
||||
c.y = val.y.y;
|
||||
b.y = CastFunctor<float2, uint32_t>()(c);
|
||||
|
||||
return b;
|
||||
}))
|
||||
|
||||
// float4_ -> float4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 b;
|
||||
b.x = val.x.x;
|
||||
b.y = val.x.y;
|
||||
b.z = val.y.x;
|
||||
b.w = val.y.y;
|
||||
return b;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
|
||||
uint4 b;
|
||||
b.x = CastFunctor<float2, uint32_t>()(val.x);
|
||||
b.y = CastFunctor<float2, uint32_t>()(val.y);
|
||||
b.z = CastFunctor<float2, uint32_t>()(val.z);
|
||||
b.w = CastFunctor<float2, uint32_t>()(val.w);
|
||||
return b;
|
||||
}))
|
||||
|
||||
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||
|
||||
#undef STMTS_WRAPPER
|
||||
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
|
||||
} // namespace funcs
|
||||
} // namespace colossalAI
|
||||
|
|
|
@ -15,21 +15,6 @@
|
|||
namespace colossalAI {
|
||||
namespace funcs {
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void zero(T& dst) {
|
||||
constexpr int WORDS = sizeof(T) / 4;
|
||||
union {
|
||||
T raw;
|
||||
uint32_t words[WORDS];
|
||||
} tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < WORDS; ii++) {
|
||||
tmp.words[ii] = 0u;
|
||||
}
|
||||
dst = tmp.raw;
|
||||
}
|
||||
|
||||
// Note(LiuYang): As a retrieved table to check which operation is supported
|
||||
// already
|
||||
enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum };
|
||||
|
|
|
@ -174,13 +174,13 @@ void context_kv_cache_memcpy(
|
|||
key.scalar_type(),
|
||||
"context_kv_cache_memcpy",
|
||||
apply_context_kv_cache_memcpy<scalar_t>(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_seq_len_in_batch
|
||||
);)
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_seq_len_in_batch
|
||||
);)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "common/micros.h"
|
||||
#include "funcs/cast_functor.h"
|
||||
|
@ -34,11 +33,25 @@ constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
|
|||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void zero(T& dst) {
|
||||
constexpr int WORDS = sizeof(T) / 4;
|
||||
union {
|
||||
T raw;
|
||||
uint32_t words[WORDS];
|
||||
} tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < WORDS; ii++) {
|
||||
tmp.words[ii] = 0u;
|
||||
}
|
||||
dst = tmp.raw;
|
||||
}
|
||||
|
||||
using colossalAI::funcs::BinaryOpType;
|
||||
using colossalAI::funcs::CastFunctor;
|
||||
using colossalAI::funcs::TernaryOpFunctor;
|
||||
using colossalAI::funcs::TernaryOpType;
|
||||
using colossalAI::funcs::zero;
|
||||
using colossalAI::common::VecTypeTrait;
|
||||
using colossalAI::common::FloatVecTypeTrait;
|
||||
using namespace colossalAI::cuda::attention;
|
||||
|
@ -84,10 +97,12 @@ __global__ void flash_decoding_attention_kernel(
|
|||
constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
|
||||
constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
|
||||
|
||||
using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
|
||||
using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
|
||||
using L_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
|
||||
using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
|
||||
using KVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
|
||||
using VVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
|
||||
using KQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
|
||||
using VQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
|
||||
using LVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
|
||||
using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int thread_group_offset = lane % NUM_THREADS_PER_X;
|
||||
|
@ -119,18 +134,18 @@ __global__ void flash_decoding_attention_kernel(
|
|||
scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
|
||||
// each warp access a whole block
|
||||
|
||||
K_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
KVecT q_vecs[NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
|
||||
const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
|
||||
const int offset1 = idx % NUM_THREADS_PER_X;
|
||||
q_vecs[i] = *reinterpret_cast<K_vec*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
|
||||
q_vecs[i] = *reinterpret_cast<KVecT*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
|
||||
}
|
||||
|
||||
for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
|
||||
|
||||
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
||||
KVecT k_vecs[NUM_VECS_PER_THREAD];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
|
||||
|
@ -142,7 +157,7 @@ __global__ void flash_decoding_attention_kernel(
|
|||
const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
|
||||
const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
|
||||
const int offset2 = idx % NUM_THREADS_PER_X;
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE);
|
||||
k_vecs[j] = CastFunctor<KQuantVecT, KVecT>()(*reinterpret_cast<const KQuantVecT*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE));
|
||||
}
|
||||
|
||||
float qk = scale * Qk_dot<scalar_t, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);
|
||||
|
@ -174,13 +189,13 @@ __global__ void flash_decoding_attention_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
Float_vec accs[NUM_ROUNDS_PER_TOKEN];
|
||||
FloatVecT accs[NUM_ROUNDS_PER_TOKEN];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
|
||||
zero(accs[i]);
|
||||
}
|
||||
|
||||
V_vec zero_value;
|
||||
VVecT zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
|
||||
|
@ -193,11 +208,11 @@ __global__ void flash_decoding_attention_kernel(
|
|||
+ kv_head_idx * kv_head_stride
|
||||
+ idx * VEC_SIZE;
|
||||
|
||||
V_vec v_vecs[NUM_ROUNDS_PER_TOKEN];
|
||||
VVecT v_vecs[NUM_ROUNDS_PER_TOKEN];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
|
||||
v_vecs[i] = (reinterpret_cast<const V_vec*>(v_ptr))[i * WARP_SIZE];
|
||||
v_vecs[i] = CastFunctor<VQuantVecT, VVecT>()(*((reinterpret_cast<const VQuantVecT*>(v_ptr) + i * WARP_SIZE)));
|
||||
}
|
||||
|
||||
if (token_idx >= context_len) {
|
||||
|
@ -210,7 +225,7 @@ __global__ void flash_decoding_attention_kernel(
|
|||
logit = CastFunctor<float, scalar_t>()(logits[token_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
|
||||
accs[i] = TernaryOpFunctor<scalar_t, V_vec, Float_vec, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
|
||||
accs[i] = TernaryOpFunctor<scalar_t, VVecT, FloatVecT, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -220,16 +235,16 @@ __global__ void flash_decoding_attention_kernel(
|
|||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
|
||||
block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
|
||||
block_sum<FloatVecT, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
|
||||
}
|
||||
|
||||
scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
L_vec out_reg;
|
||||
LVecT out_reg;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
|
||||
if (thread_idx < NUM_THREADS_PER_TOKEN) {
|
||||
out_reg = CastFunctor<Float_vec, L_vec>()(accs[i]);
|
||||
(reinterpret_cast<L_vec*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
|
||||
out_reg = CastFunctor<FloatVecT, LVecT>()(accs[i]);
|
||||
(reinterpret_cast<LVecT*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -353,18 +368,40 @@ void flash_decoding_attention(
|
|||
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
|
||||
float scale) {
|
||||
switch (query.scalar_type()) {
|
||||
case at::ScalarType::Float:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unsupported data type: ", toString(query.scalar_type()));
|
||||
|
||||
|
||||
TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Dtype of query should be float, half or bfloat16!");
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key_cache.scalar_type(),
|
||||
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
|
||||
|
||||
if(key_cache.scalar_type() == at::ScalarType::Byte)
|
||||
{
|
||||
switch (query.scalar_type()) {
|
||||
case at::ScalarType::Float:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(half, uint8_t);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (query.scalar_type()) {
|
||||
case at::ScalarType::Float:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue