#pragma once #if defined(COLOSSAL_WITH_CUDA) #include #include #include #include #endif #include #include #include "cast_functor.h" #include "common/micros.h" namespace colossalAI { namespace funcs { enum class TernaryOpType { kFma = 0 }; template struct TernaryOpFunctor; #define STMTS_WRAPPER(...) __VA_ARGS__ #define COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( \ LT, RT, RET, TERNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \ template \ struct TernaryOpFunctor { \ FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS \ }; #if defined(COLOSSAL_WITH_CUDA) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float d; d = fma(a, b, c); return d; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float2, float2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float2 d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); return d; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float2 d; d.x = fma(a, b.x, c.x); d.y = fma(a, b.y, c.y); return d; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float4, float4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float4 d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); d.z = fma(a.z, b.z, c.z); d.w = fma(a.w, b.w, c.w); return d; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float4 d; d.x = fma(a, b.x, c.x); d.y = fma(a, b.y, c.y); d.z = fma(a, b.z, c.z); d.w = fma(a, b.w, c.w); return d; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( half, half, float, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ return __half2float(a) * __half2float(b) + c; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( half2, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ CastFunctor cast; TernaryOpFunctor fma; float2 fa = cast(a); float2 fb = cast(b); return fma(fa, fb, c); })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( half, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ CastFunctor cast; TernaryOpFunctor fma; return fma(cast(a), b, c); })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float4 fd; CastFunctor cast; TernaryOpFunctor fma; fd = fma(cast(a), cast(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float4 fd; CastFunctor cast0; CastFunctor cast1; TernaryOpFunctor fma; fd = fma(cast0(a), cast1(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ dtype::float8 fd; TernaryOpFunctor fma; fd.x = fma(a.x, b.x, c.x); fd.y = fma(a.y, b.y, c.y); fd.z = fma(a.z, b.z, c.z); fd.w = fma(a.w, b.w, c.w); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ dtype::float8 fd; CastFunctor cast; TernaryOpFunctor fma; half2 s = cast(a); fd.x = fma(s, b.x, c.x); fd.y = fma(s, b.y, c.y); fd.z = fma(s, b.z, c.z); fd.w = fma(s, b.w, c.w); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( __nv_bfloat16, __nv_bfloat16, float, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ return __bfloat162float(a) * __bfloat162float(b) + c; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( __nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ CastFunctor<__nv_bfloat162, float2> cast; TernaryOpFunctor fma; float2 fa = cast(a); float2 fb = cast(b); return fma(fa, fb, c); })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( __nv_bfloat16, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma> fma; return fma(cast(a), b, c); })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float4 fd; CastFunctor cast; TernaryOpFunctor fma; fd = fma(cast(a), cast(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( __nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ float4 fd; CastFunctor<__nv_bfloat16, float> cast0; CastFunctor cast1; TernaryOpFunctor fma; fd = fma(cast0(a), cast1(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ dtype::float8 fd; TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma> fma; fd.x = fma(a.x, b.x, c.x); fd.y = fma(a.y, b.y, c.y); fd.z = fma(a.z, b.z, c.z); fd.w = fma(a.w, b.w, c.w); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( __nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ dtype::float8 fd; CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma> fma; __nv_bfloat162 s = cast(a); fd.x = fma(s, b.x, c.x); fd.y = fma(s, b.y, c.y); fd.z = fma(s, b.z, c.z); fd.w = fma(s, b.w, c.w); return fd; })) #endif /* defined(COLOSSAL_WITH_CUDA) */ #undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION #undef STMTS_WRAPPER } // namespace funcs } // namespace colossalAI