@ -1,18 +1,20 @@
# pragma once
# if defined(COLOSSAL_WITH_CUDA)
# include <cuda.h>
# include <cuda_bf16.h>
# include <cuda_fp16.h>
# include <cuda_runtime.h>
# endif
# include <float.h>
# include <functional>
# include " ../funcs/ cast_functor.h"
# include " ../utils /micros.h"
# include " cast_functor.h"
# include " common /micros.h"
namespace colossalAI {
namespace cuda {
namespace funcs {
enum class TernaryOpType { kFma = 0 } ;
@ -29,6 +31,7 @@ struct TernaryOpFunctor;
FUNCTION_MODIFIER RET operator ( ) ( LT a , RT b , RET c ) STMTS \
} ;
# if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION ( float , float , float ,
TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
@ -91,16 +94,18 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fma ( cast ( a ) , b , c ) ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
half4 , half4 , float4_ , TernaryOpType : : kFma , DEVICE , STMTS_WRAPPER ( {
float4_ fd ;
dtype : : half4 , dtype : : half4 , dtype : : float4_ , TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
dtype : : float4_ fd ;
TernaryOpFunctor < half2 , half2 , float2 , TernaryOpType : : kFma > fma ;
fd . x = fma ( a . x , b . x , c . x ) ;
fd . y = fma ( a . y , b . y , c . y ) ;
return fd ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
half , half4 , float4_ , TernaryOpType : : kFma , DEVICE , STMTS_WRAPPER ( {
float4_ fd ;
half , dtype : : half4 , dtype : : float4_ , TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
dtype : : float4_ fd ;
CastFunctor < half , half2 > cast ;
TernaryOpFunctor < half2 , half2 , float2 , TernaryOpType : : kFma > fma ;
half2 s = cast ( a ) ;
@ -109,8 +114,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
half8 , half8 , float8_ , TernaryOpType : : kFma , DEVICE , STMTS_WRAPPER ( {
float8_ fd ;
dtype : : half8 , dtype : : half8 , dtype : : float8_ , TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
dtype : : float8_ fd ;
TernaryOpFunctor < half2 , half2 , float2 , TernaryOpType : : kFma > fma ;
fd . x = fma ( a . x , b . x , c . x ) ;
fd . y = fma ( a . y , b . y , c . y ) ;
@ -119,8 +125,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
half , half8 , float8_ , TernaryOpType : : kFma , DEVICE , STMTS_WRAPPER ( {
float8_ fd ;
half , dtype : : half8 , dtype : : float8_ , TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
dtype : : float8_ fd ;
CastFunctor < half , half2 > cast ;
TernaryOpFunctor < half2 , half2 , float2 , TernaryOpType : : kFma > fma ;
half2 s = cast ( a ) ;
@ -153,8 +160,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fma ( cast ( a ) , b , c ) ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
bfloat164 , bfloat164 , float4_ , TernaryOpType : : kFma , DEVICE , STMTS_WRAPPER ( {
float4_ fd ;
dtype : : bfloat164 , dtype : : bfloat164 , dtype : : float4_ , TernaryOpType : : kFma ,
DEVICE , STMTS_WRAPPER ( {
dtype : : float4_ fd ;
TernaryOpFunctor < __nv_bfloat162 , __nv_bfloat162 , float2 ,
TernaryOpType : : kFma >
fma ;
@ -163,9 +171,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
__nv_bfloat16 , bfloat164, float4_, TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
float4_ fd ;
__nv_bfloat16 , dtype: : bfloat164, dtype: : float4_, TernaryOpType : : kFma ,
DEVICE , STMTS_WRAPPER ( {
dtype: : float4_ fd ;
CastFunctor < __nv_bfloat16 , __nv_bfloat162 > cast ;
TernaryOpFunctor < __nv_bfloat162 , __nv_bfloat162 , float2 ,
TernaryOpType : : kFma >
@ -176,8 +184,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
bfloat168 , bfloat168 , float8_ , TernaryOpType : : kFma , DEVICE , STMTS_WRAPPER ( {
float8_ fd ;
dtype : : bfloat168 , dtype : : bfloat168 , dtype : : float8_ , TernaryOpType : : kFma ,
DEVICE , STMTS_WRAPPER ( {
dtype : : float8_ fd ;
TernaryOpFunctor < __nv_bfloat162 , __nv_bfloat162 , float2 ,
TernaryOpType : : kFma >
fma ;
@ -188,9 +197,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd ;
} ) )
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION (
__nv_bfloat16 , bfloat168, float8_, TernaryOpType : : kFma , DEVICE ,
STMTS_WRAPPER ( {
float8_ fd ;
__nv_bfloat16 , dtype: : bfloat168, dtype: : float8_, TernaryOpType : : kFma ,
DEVICE , STMTS_WRAPPER ( {
dtype: : float8_ fd ;
CastFunctor < __nv_bfloat16 , __nv_bfloat162 > cast ;
TernaryOpFunctor < __nv_bfloat162 , __nv_bfloat162 , float2 ,
TernaryOpType : : kFma >
@ -203,10 +212,10 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd ;
} ) )
# undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION
# endif /* defined(COLOSSAL_WITH_CUDA) */
# undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION
# undef STMTS_WRAPPER
} // namespace funcs
} // namespace cuda
} // namespace colossalAI