diff --git a/extensions/csrc/common/vec_type_traits.h b/extensions/csrc/common/vec_type_traits.h
index 6ea6d7a38..f7e70e22c 100644
--- a/extensions/csrc/common/vec_type_traits.h
+++ b/extensions/csrc/common/vec_type_traits.h
@@ -5,6 +5,7 @@
 #include <cuda_fp16.h>
 #endif
 
+#include <cstdint>
 #include <torch/extension.h>
 
 #include "common/data_type.h"
@@ -27,6 +28,7 @@ struct FloatVecTypeTrait {};
 VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
 
 #if defined(COLOSSAL_WITH_CUDA)
+
 VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
 VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
 VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
@@ -35,18 +37,19 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
 VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
 VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
 VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
-VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
-VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
-VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
-VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
-VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
-VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
+
+VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, uint16_t)
+VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, uint32_t)
+VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, uint2)
 VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
 VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164);
 VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168);
 VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
 VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
 VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
+VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
 #endif /* defined(COLOSSAL_WITH_CUDA) */
 
 #undef VEC_TYPE_TRAITS_SPECIALIZATION
diff --git a/extensions/csrc/funcs/cast_functor.h b/extensions/csrc/funcs/cast_functor.h
index 7fc22fb44..d33eece59 100644
--- a/extensions/csrc/funcs/cast_functor.h
+++ b/extensions/csrc/funcs/cast_functor.h
@@ -4,9 +4,12 @@
 #include <cuda.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+#include <cuda_fp8.h>
 #include <cuda_runtime.h>
 #endif
 
+#include <assert.h>
+
 #include <functional>
 
 #include "common/data_type.h"
@@ -23,141 +26,421 @@ struct CastFunctor : public std::unary_function<From, To> {
   HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
 };
 
-#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMTS,              \
-                                             FUNCTION_MODIFIER)            \
-  template <>                                                              \
-  struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> {    \
-    FUNCTION_MODIFIER TO operator()(FROM val) STMTS                        \
+#define STMTS_WRAPPER(...) __VA_ARGS__
+
+#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, FUNCTION_MODIFIER,  \
+                                             STMTS)                        \
+  template <>                                                              \
+  struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> {    \
+    FUNCTION_MODIFIER TO operator()(FROM val) STMTS                        \
   };
 
 #if defined(COLOSSAL_WITH_CUDA)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float, float2, { return make_float2(val, val); }, DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, DEVICE, STMTS_WRAPPER({
+                                       return make_float2(val.x, val.y);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, DEVICE, STMTS_WRAPPER({
+                                       return make_float2(val, val);
+                                     }))
 
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    half2, float2, { return __half22float2(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float2, half2, { return __float22half2_rn(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float, half, { return __float2half_rn(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float, half2, { return __float2half2_rn(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    half, half2, { return __half2half2(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    half, float, { return __half2float(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float4, dtype::half4,
-    {
-      dtype::half4 dst;
-      dst.x = __floats2half2_rn(val.x, val.y);
-      dst.y = __floats2half2_rn(val.z, val.w);
-      return dst;
-    },
-    DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    dtype::float4_, dtype::half4,
-    {
-      dtype::half4 dst;
-      dst.x = __float22half2_rn(val.x);
-      dst.y = __float22half2_rn(val.y);
-      return dst;
-    },
-    DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    dtype::float8_, dtype::half8,
-    {
-      dtype::half8 dst;
-      dst.x = __float22half2_rn(val.x);
-      dst.y = __float22half2_rn(val.y);
-      dst.z = __float22half2_rn(val.z);
-      dst.w = __float22half2_rn(val.w);
-      return dst;
-    },
-    DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, DEVICE, STMTS_WRAPPER({
+                                       return __half22float2(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, DEVICE, STMTS_WRAPPER({
+                                       return __float22half2_rn(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, DEVICE, STMTS_WRAPPER({
+                                       return __float2half_rn(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, DEVICE, STMTS_WRAPPER({
+                                       return __float2half2_rn(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, DEVICE, STMTS_WRAPPER({
+                                       return __half2half2(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, DEVICE, STMTS_WRAPPER({
+                                       return __half2float(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
+                                     STMTS_WRAPPER({
+                                       dtype::half4 dst;
+                                       dst.x = __floats2half2_rn(val.x, val.y);
+                                       dst.y = __floats2half2_rn(val.z, val.w);
+                                       return dst;
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
+                                     STMTS_WRAPPER({
+                                       dtype::half4 dst;
+                                       dst.x = __float22half2_rn(val.x);
+                                       dst.y = __float22half2_rn(val.y);
+                                       return dst;
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
+                                     STMTS_WRAPPER({
+                                       dtype::half8 dst;
+                                       dst.x = __float22half2_rn(val.x);
+                                       dst.y = __float22half2_rn(val.y);
+                                       dst.z = __float22half2_rn(val.z);
+                                       dst.w = __float22half2_rn(val.w);
+                                       return dst;
+                                     }))
 
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float, __nv_bfloat162, { return __float2bfloat162_rn(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float4, dtype::bfloat164,
-    {
-      dtype::bfloat164 dst;
-      dst.x = __floats2bfloat162_rn(val.x, val.y);
-      dst.y = __floats2bfloat162_rn(val.z, val.w);
-      return dst;
-    },
-    DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __float2bfloat162_rn(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __float2bfloat16_rn(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
+                                     STMTS_WRAPPER({
+                                       dtype::bfloat164 dst;
+                                       dst.x =
+                                           __floats2bfloat162_rn(val.x, val.y);
+                                       dst.y =
+                                           __floats2bfloat162_rn(val.z, val.w);
+                                       return dst;
+                                     }))
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    __nv_bfloat16, __nv_bfloat162, { return __bfloat162bfloat162(val); },
-    DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    __nv_bfloat162, float2, { return __bfloat1622float2(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    dtype::float4_, dtype::bfloat164,
-    {
-      dtype::bfloat164 dst;
-      dst.x = __float22bfloat162_rn(val.x);
-      dst.y = __float22bfloat162_rn(val.y);
-      return dst;
-    },
-    DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    dtype::float8_, dtype::bfloat168,
-    {
-      dtype::bfloat168 dst;
-      dst.x = __float22bfloat162_rn(val.x);
-      dst.y = __float22bfloat162_rn(val.y);
-      dst.z = __float22bfloat162_rn(val.z);
-      dst.w = __float22bfloat162_rn(val.w);
-      return dst;
-    },
-    DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __bfloat162bfloat162(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __bfloat1622float2(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __float22bfloat162_rn(val);
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
+                                     STMTS_WRAPPER({
+                                       dtype::bfloat164 dst;
+                                       dst.x = __float22bfloat162_rn(val.x);
+                                       dst.y = __float22bfloat162_rn(val.y);
+                                       return dst;
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
+                                     STMTS_WRAPPER({
+                                       dtype::bfloat168 dst;
+                                       dst.x = __float22bfloat162_rn(val.x);
+                                       dst.y = __float22bfloat162_rn(val.y);
+                                       dst.z = __float22bfloat162_rn(val.z);
+                                       dst.w = __float22bfloat162_rn(val.w);
+                                       return dst;
+                                     }))
 #else
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
+                                     STMTS_WRAPPER({
+                                       __nv_bfloat162 dst;
+                                       dst.x = val;
+                                       dst.y = val;
+                                       return dst;
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return make_float2(__low2float(val),
+                                                          __high2float(val));
+                                     }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __floats2bfloat162_rn(val.x,
+                                                                    val.y);
+                                     }))
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    __nv_bfloat16, __nv_bfloat162,
-    {
-      __nv_bfloat162 dst;
-      dst.x = val;
-      dst.y = val;
-      return dst;
-    },
-    DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    __nv_bfloat162, float2,
-    { return make_float2(__low2float(val), __high2float(val)); }, DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
-    DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    dtype::float4_, dtype::bfloat164,
-    {
+    dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
       dtype::bfloat164 dst;
       dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
       dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
       return dst;
-    },
-    DEVICE)
+    }))
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    dtype::float8_, dtype::bfloat168,
-    {
+    dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
       dtype::bfloat168 dst;
       dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
       dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
       dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
       dst.w = __floats2bfloat162_rn(val.w.x, val.w.y);
       return dst;
-    },
-    DEVICE)
+    }))
 #endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */
+
+// quant utils
+// fp8 -> half raw
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
+                                       __half_raw res =
+                                           __nv_cvt_fp8_to_halfraw(val,
+                                                                   __NV_E5M2);
+                                       return res.x;
+                                     }))
+
+// fp8x2 -> half2 raw
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE,
+                                     STMTS_WRAPPER({
+                                       union {
+                                         uint16_t u16[2];
+                                         uint32_t u32;
+                                       } tmp;
+                                       __half2_raw res =
+                                           __nv_cvt_fp8x2_to_halfraw2(
+                                               val, __NV_E5M2);
+                                       tmp.u16[0] = res.x;
+                                       tmp.u16[1] = res.y;
+                                       return tmp.u32;
+                                     }))
+
+// fp8x4 -> half2x2 raw
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, uint2, DEVICE, STMTS_WRAPPER({
+      union {
+        uint2 u32x2;
+        uint32_t u32[2];
+      } tmp;
+      tmp.u32[0] =
+          CastFunctor<uint16_t, uint32_t>()(static_cast<uint16_t>(val));
+      tmp.u32[1] =
+          CastFunctor<uint16_t, uint32_t>()(static_cast<uint16_t>(val >> 16U));
+      return tmp.u32x2;
+    }))
+
+// fp8x8 -> half2x4 raw
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint2, uint4, DEVICE, STMTS_WRAPPER({
+      union {
+        uint4 u64x2;
+        uint2 u64[2];
+      } tmp;
+      tmp.u64[0] = CastFunctor<uint32_t, uint2>()(val.x);
+      tmp.u64[1] = CastFunctor<uint32_t, uint2>()(val.y);
+      return tmp.u64x2;
+    }))
+
+// fp8 -> half
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
+                                       __half_raw res =
+                                           __nv_cvt_fp8_to_halfraw(val,
+                                                                   __NV_E5M2);
+                                       return half(res);
+                                     }))
+
+// fp8x2 -> half2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
+                                       __half2_raw res =
+                                           __nv_cvt_fp8x2_to_halfraw2(
+                                               val, __NV_E5M2);
+                                       return half2(res);
+                                     }))
+
+// fp8x4 -> half4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
+      half2 tmp1, tmp2;
+      tmp1 = CastFunctor<uint16_t, half2>()(static_cast<uint16_t>(val));
+      tmp2 = CastFunctor<uint16_t, half2>()(static_cast<uint16_t>(val >> 16U));
+      dtype::half4 res;
+      res.x = tmp1;
+      res.y = tmp2;
+      return res;
+    }))
+
+// fp8x8 -> half8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
+      dtype::half4 tmp1, tmp2;
+      tmp1 = CastFunctor<uint32_t, dtype::half4>()(val.x);
+      tmp2 = CastFunctor<uint32_t, dtype::half4>()(val.y);
+      dtype::half8 res;
+      res.x = tmp1.x;
+      res.y = tmp1.y;
+      res.z = tmp2.x;
+      res.w = tmp2.y;
+      return res;
+    }))
+
+// fp8 -> __nv_bfloat16
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint8_t, __nv_bfloat16, DEVICE, STMTS_WRAPPER({
+      // Note there is no direct convert function from fp8 to bf16.
+      // fp8 -> half
+      __half_raw res = __nv_cvt_fp8_to_halfraw(val, __NV_E5M2);
+      // half -> float -> bf16
+      float tmp;
+      asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(tmp) : "h"(res.x));
+      return __float2bfloat16(tmp);
+    }))
+
+// fp8x2 -> __nv_bfloat162
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint16_t, __nv_bfloat162, DEVICE, STMTS_WRAPPER({
+      __nv_bfloat162 res;
+      res.x = CastFunctor<uint8_t, __nv_bfloat16>()(static_cast<uint8_t>(val));
+      res.y = CastFunctor<uint8_t, __nv_bfloat16>()(
+          static_cast<uint8_t>(val >> 8U));
+      return res;
+    }))
+
+// fp8x4 -> bfloat164
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
+      dtype::bfloat164 res;
+      res.x =
+          CastFunctor<uint16_t, __nv_bfloat162>()(static_cast<uint16_t>(val));
+      res.y = CastFunctor<uint16_t, __nv_bfloat162>()(
+          static_cast<uint16_t>(val >> 16U));
+      return res;
+    }))
+
+// fp8x8 -> bfloat168
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint2, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
+      dtype::bfloat164 tmp1, tmp2;
+      tmp1 = CastFunctor<uint32_t, dtype::bfloat164>()(val.x);
+      tmp2 = CastFunctor<uint32_t, dtype::bfloat164>()(val.y);
+      dtype::bfloat168 res;
+      res.x = tmp1.x;
+      res.y = tmp1.y;
+      res.z = tmp2.x;
+      res.w = tmp2.y;
+      return res;
+    }))
+
+// fp8 -> float
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint8_t, float, DEVICE, STMTS_WRAPPER({
+      // fp8 -> half
+      uint16_t tmp = CastFunctor<uint8_t, uint16_t>()(val);
+      // half -> float
+      float res;
+      asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(res) : "h"(tmp));
+      return res;
+    }))
+
+// fp8x2 -> float2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint16_t, float2, DEVICE, STMTS_WRAPPER({
+      // fp8x2 -> half2
+      uint32_t tmp = CastFunctor<uint16_t, uint32_t>()(val);
+      // half2 -> float2
+      uint16_t lo, hi;
+      asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(tmp));
+      float lof, hif;
+      asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(lof) : "h"(lo));
+      asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(hif) : "h"(hi));
+      return make_float2(lof, hif);
+    }))
+
+// fp8x4 -> float4_
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
+      dtype::float4_ res;
+      res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
+      res.y =
+          CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
+      return res;
+    }))
+
+// fp8x8 -> float8_
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
+      dtype::float4_ tmp1, tmp2;
+      tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
+      tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
+      dtype::float8_ res;
+      res.x = tmp1.x;
+      res.y = tmp1.y;
+      res.z = tmp2.x;
+      res.w = tmp2.y;
+      return res;
+    }))
+
+// half -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
+                                       __half_raw tmp;
+                                       tmp.x = val;
+                                       __nv_fp8_storage_t res =
+                                           __nv_cvt_halfraw_to_fp8(
+                                               tmp, __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint8_t>(res);
+                                     }))
+
+// bf16 -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
+                                     STMTS_WRAPPER({
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+                                       assert(false);
+#else
+                                       __nv_fp8_storage_t res =
+                                           __nv_cvt_bfloat16raw_to_fp8(
+                                               __nv_bfloat16_raw(val),
+                                               __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint8_t>(res);
+#endif
+                                     }))
+
+// float -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
+                                       __nv_fp8_storage_t res =
+                                           __nv_cvt_float_to_fp8(
+                                               val, __NV_SATFINITE, __NV_E5M2);
+                                       return static_cast<uint8_t>(res);
+                                     }))
+
+// fp8x4 -> float4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, float4, DEVICE, STMTS_WRAPPER({
+      dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
+      float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+      return res;
+    }))
+
+// float2 -> half2 raw (bit pattern packed into a uint32_t)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
+                                       union {
+                                         half2 float16;
+                                         uint32_t uint32;
+                                       };
+
+                                       float16 = __float22half2_rn(val);
+                                       return uint32;
+                                     }))
+
+// float4_ -> half2x2 raw
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
+                                     STMTS_WRAPPER({
+                                       uint2 b;
+                                       float2 c;
+                                       c.x = val.x.x;
+                                       c.y = val.x.y;
+                                       b.x = CastFunctor<float2, uint32_t>()(c);
+
+                                       c.x = val.y.x;
+                                       c.y = val.y.y;
+                                       b.y = CastFunctor<float2, uint32_t>()(c);
+
+                                       return b;
+                                     }))
+
+// float4_ -> float4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
+                                     STMTS_WRAPPER({
+                                       float4 b;
+                                       b.x = val.x.x;
+                                       b.y = val.x.y;
+                                       b.z = val.y.x;
+                                       b.w = val.y.y;
+                                       return b;
+                                     }))
+
+// float8_ -> half2x4 raw
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
+      uint4 b;
+      b.x = CastFunctor<float2, uint32_t>()(val.x);
+      b.y = CastFunctor<float2, uint32_t>()(val.y);
+      b.z = CastFunctor<float2, uint32_t>()(val.z);
+      b.w = CastFunctor<float2, uint32_t>()(val.w);
+      return b;
+    }))
+
 #endif /* defined(COLOSSAL_WITH_CUDA) */
 
+#undef STMTS_WRAPPER
 #undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
 }  // namespace funcs
 }  // namespace colossalAI
diff --git a/extensions/csrc/funcs/unary_functor.h b/extensions/csrc/funcs/unary_functor.h
index e1d23792a..ea75018df 100644
--- a/extensions/csrc/funcs/unary_functor.h
+++ b/extensions/csrc/funcs/unary_functor.h
@@ -15,21 +15,6 @@
 namespace colossalAI {
 namespace funcs {
 
-template <typename T>
-inline __device__ void zero(T& dst) {
-  constexpr int WORDS = sizeof(T) / 4;
-  union {
-    T raw;
-    uint32_t words[WORDS];
-  } tmp;
-
-#pragma unroll
-  for (int ii = 0; ii < WORDS; ii++) {
-    tmp.words[ii] = 0u;
-  }
-  dst = tmp.raw;
-}
-
 // Note(LiuYang): As a retrieved table to check which operation is supported
 // already
 enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum };
diff --git a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
index 6e05434b8..9b3a8261e 100644
--- a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
@@ -174,13 +174,13 @@ void context_kv_cache_memcpy(
       key.scalar_type(),
       "context_kv_cache_memcpy",
       apply_context_kv_cache_memcpy(
-        key,
-        value,
-        key_cache,
-        value_cache,
-        sequence_lengths,
-        cu_seqlens,
-        block_tables,
-        max_seq_len_in_batch
-      );)
+          key,
+          value,
+          key_cache,
+          value_cache,
+          sequence_lengths,
+          cu_seqlens,
+          block_tables,
+          max_seq_len_in_batch
+      );)
 }
diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
index a004a98c3..9e933ff2a 100644
--- a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
+++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
@@ -5,7 +5,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <stdio.h>
 
 #include "common/micros.h"
 #include "funcs/cast_functor.h"
@@ -34,11 +33,25 @@ constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
   return v;
 }
 
+// Moved here from funcs/unary_functor.h: zero-fills a vector type one
+// 32-bit word at a time.
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ii++) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
 using colossalAI::funcs::BinaryOpType;
 using colossalAI::funcs::CastFunctor;
 using colossalAI::funcs::TernaryOpFunctor;
 using colossalAI::funcs::TernaryOpType;
-using colossalAI::funcs::zero;
 using colossalAI::common::VecTypeTrait;
 using colossalAI::common::FloatVecTypeTrait;
 using namespace colossalAI::cuda::attention;
@@ -84,10 +97,12 @@ __global__ void flash_decoding_attention_kernel(
   constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
   constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
 
-  using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
-  using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
-  using L_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
-  using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
+  using KVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
+  using VVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
+  using KQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
+  using VQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
+  using LVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
+  using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
 
   const int context_len = context_lens[seq_idx];
   const int thread_group_offset = lane % NUM_THREADS_PER_X;
@@ -119,18 +134,18 @@ __global__ void flash_decoding_attention_kernel(
   scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
   // each warp access a whole block
-  K_vec q_vecs[NUM_VECS_PER_THREAD];
+  KVecT q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
     const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
     const int offset1 = idx % NUM_THREADS_PER_X;
-    q_vecs[i] = *reinterpret_cast<KVecT*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
+    q_vecs[i] = *reinterpret_cast<KVecT*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
   }
 
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
     const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
 
-    K_vec k_vecs[NUM_VECS_PER_THREAD];
+    KVecT k_vecs[NUM_VECS_PER_THREAD];
 
 #pragma unroll
     for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
@@ -142,7 +157,7 @@ __global__ void flash_decoding_attention_kernel(
         const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
         const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
         const int offset2 = idx % NUM_THREADS_PER_X;
-        k_vecs[j] = *reinterpret_cast<const KVecT*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE);
+        k_vecs[j] = CastFunctor<KQuantVecT, KVecT>()(*reinterpret_cast<const KQuantVecT*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE));
       }
 
       float qk = scale * Qk_dot<scalar_t, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);
@@ -174,13 +189,13 @@ __global__ void flash_decoding_attention_kernel(
   }
   __syncthreads();
 
-  Float_vec accs[NUM_ROUNDS_PER_TOKEN];
+  FloatVecT accs[NUM_ROUNDS_PER_TOKEN];
 #pragma unroll
   for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
     zero(accs[i]);
   }
 
-  V_vec zero_value;
+  VVecT zero_value;
   zero(zero_value);
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
     const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
@@ -193,11 +208,11 @@ __global__ void flash_decoding_attention_kernel(
                                  + kv_head_idx * kv_head_stride
                                  + idx * VEC_SIZE;
 
-      V_vec v_vecs[NUM_ROUNDS_PER_TOKEN];
+      VVecT v_vecs[NUM_ROUNDS_PER_TOKEN];
 
 #pragma unroll
       for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-        v_vecs[i] = (reinterpret_cast<const VVecT*>(v_ptr))[i * WARP_SIZE];
+        v_vecs[i] = CastFunctor<VQuantVecT, VVecT>()(*((reinterpret_cast<const VQuantVecT*>(v_ptr) + i * WARP_SIZE)));
       }
 
       if (token_idx >= context_len) {
@@ -210,7 +225,7 @@ __global__ void flash_decoding_attention_kernel(
         logit = CastFunctor<float, scalar_t>()(logits[token_idx]);
 #pragma unroll
         for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-          accs[i] = TernaryOpFunctor<scalar_t, V_vec, Float_vec, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
+          accs[i] = TernaryOpFunctor<scalar_t, VVecT, FloatVecT, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
         }
       }
     }
@@ -220,16 +235,16 @@ __global__ void flash_decoding_attention_kernel(
 
 #pragma unroll
   for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-    block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN>(out_shared_mem, accs[i]);
+    block_sum<FloatVecT, NUM_WARPS, NUM_THREADS_PER_TOKEN>(out_shared_mem, accs[i]);
   }
 
   scalar_t* out_ptr = out + seq_idx * q_stride +
                       head_idx * HEAD_SIZE;
 
-  L_vec out_reg;
+  LVecT out_reg;
 #pragma unroll
   for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
     if (thread_idx < NUM_THREADS_PER_TOKEN) {
-      out_reg = CastFunctor<Float_vec, L_vec>()(accs[i]);
-      (reinterpret_cast<L_vec*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
+      out_reg = CastFunctor<FloatVecT, LVecT>()(accs[i]);
+      (reinterpret_cast<LVecT*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
     }
   }
 }
@@ -353,18 +368,40 @@ void flash_decoding_attention(
     torch::Tensor& tmp_out,      // [num_tokens, num_heads, max_num_partitions, head_size]
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
     float scale) {
-  switch (query.scalar_type()) {
-    case at::ScalarType::Float:
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
-      break;
-    case at::ScalarType::Half:
-      CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
-      break;
-    case at::ScalarType::BFloat16:
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
-      break;
-    default:
-      AT_ERROR("Unsupported data type: ", toString(query.scalar_type()));
+
+  TORCH_CHECK(query.scalar_type() == at::ScalarType::Float ||
+                  query.scalar_type() == at::ScalarType::Half ||
+                  query.scalar_type() == at::ScalarType::BFloat16,
+              "Dtype of query should be float, half or bfloat16!");
+  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte ||
+                  key_cache.scalar_type() == query.scalar_type(),
+              "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8 (byte)!");
+
+  if (key_cache.scalar_type() == at::ScalarType::Byte) {
+    switch (query.scalar_type()) {
+      case at::ScalarType::Float:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t);
+        break;
+      case at::ScalarType::Half:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(half, uint8_t);
+        break;
+      case at::ScalarType::BFloat16:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t);
+        break;
+    }
+  } else {
+    switch (query.scalar_type()) {
+      case at::ScalarType::Float:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
+        break;
+      case at::ScalarType::Half:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
+        break;
+      case at::ScalarType::BFloat16:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
+        break;
+    }
   }
 }
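
Context for reviewers: the fp8 KV cache stores E5M2 values as raw uint8_t, and every fp8 -> T functor above routes through half because cuda_fp8.h only provides raw fp8 <-> half/float/bfloat16 conversions. Below is a minimal standalone sketch of that round trip (float -> fp8 -> half -> float), assuming CUDA 11.8+ for cuda_fp8.h; the kernel name and launch setup are illustrative only, not part of this PR.

// Illustrative sketch: exercises the same __nv_cvt_* intrinsics the
// CastFunctor specializations rely on (E5M2 interpretation, __NV_SATFINITE).
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cstdint>

__global__ void fp8_e5m2_roundtrip(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // float -> fp8 (as in CastFunctor<float, uint8_t>): overflow saturates to
  // the largest finite E5M2 value instead of producing inf.
  uint8_t q = static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(in[i], __NV_SATFINITE, __NV_E5M2));
  // fp8 -> half raw -> float (as in CastFunctor<uint8_t, float>, minus the
  // inline PTX): E5M2 shares half's 5 exponent bits, so this step is exact.
  __half_raw h = __nv_cvt_fp8_to_halfraw(q, __NV_E5M2);
  out[i] = __half2float(__half(h));
}

Because E5M2 is effectively a truncated half, dequantization is cheap and lossless in the fp8 -> half direction, which is presumably why the attention kernel casts KQuantVecT/VQuantVecT up to KVecT/VVecT before the dot product rather than computing in fp8.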