|
|
@ -94,29 +94,27 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
return fma(cast(a), b, c); |
|
|
|
return fma(cast(a), b, c); |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE, |
|
|
|
dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE, |
|
|
|
STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
dtype::float4_ fd; |
|
|
|
float4 fd; |
|
|
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma; |
|
|
|
CastFunctor<dtype::half4, float4> cast; |
|
|
|
fd.x = fma(a.x, b.x, c.x); |
|
|
|
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma; |
|
|
|
fd.y = fma(a.y, b.y, c.y); |
|
|
|
fd = fma(cast(a), cast(b), c); |
|
|
|
return fd; |
|
|
|
return fd; |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE, |
|
|
|
half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
float4 fd; |
|
|
|
dtype::float4_ fd; |
|
|
|
CastFunctor<half, float> cast0; |
|
|
|
CastFunctor<half, half2> cast; |
|
|
|
CastFunctor<dtype::half4, float4> cast1; |
|
|
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma; |
|
|
|
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma; |
|
|
|
half2 s = cast(a); |
|
|
|
fd = fma(cast0(a), cast1(b), c); |
|
|
|
fd.x = fma(s, b.x, c.x); |
|
|
|
|
|
|
|
fd.y = fma(s, b.y, c.y); |
|
|
|
|
|
|
|
return fd; |
|
|
|
return fd; |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE, |
|
|
|
dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, |
|
|
|
STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
dtype::float8_ fd; |
|
|
|
dtype::float8 fd; |
|
|
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma; |
|
|
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma; |
|
|
|
fd.x = fma(a.x, b.x, c.x); |
|
|
|
fd.x = fma(a.x, b.x, c.x); |
|
|
|
fd.y = fma(a.y, b.y, c.y); |
|
|
|
fd.y = fma(a.y, b.y, c.y); |
|
|
@ -125,9 +123,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
return fd; |
|
|
|
return fd; |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE, |
|
|
|
half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, |
|
|
|
STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
dtype::float8_ fd; |
|
|
|
dtype::float8 fd; |
|
|
|
CastFunctor<half, half2> cast; |
|
|
|
CastFunctor<half, half2> cast; |
|
|
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma; |
|
|
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma; |
|
|
|
half2 s = cast(a); |
|
|
|
half2 s = cast(a); |
|
|
@ -160,33 +158,28 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
return fma(cast(a), b, c); |
|
|
|
return fma(cast(a), b, c); |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma, |
|
|
|
dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, |
|
|
|
DEVICE, STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
dtype::float4_ fd; |
|
|
|
float4 fd; |
|
|
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, |
|
|
|
CastFunctor<dtype::bfloat164, float4> cast; |
|
|
|
TernaryOpType::kFma> |
|
|
|
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma; |
|
|
|
fma; |
|
|
|
fd = fma(cast(a), cast(b), c); |
|
|
|
fd.x = fma(a.x, b.x, c.x); |
|
|
|
|
|
|
|
fd.y = fma(a.y, b.y, c.y); |
|
|
|
|
|
|
|
return fd; |
|
|
|
return fd; |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
__nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma, |
|
|
|
__nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, |
|
|
|
DEVICE, STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
dtype::float4_ fd; |
|
|
|
float4 fd; |
|
|
|
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; |
|
|
|
CastFunctor<__nv_bfloat16, float> cast0; |
|
|
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, |
|
|
|
CastFunctor<dtype::bfloat164, float4> cast1; |
|
|
|
TernaryOpType::kFma> |
|
|
|
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma; |
|
|
|
fma; |
|
|
|
fd = fma(cast0(a), cast1(b), c); |
|
|
|
__nv_bfloat162 s = cast(a); |
|
|
|
|
|
|
|
fd.x = fma(s, b.x, c.x); |
|
|
|
|
|
|
|
fd.y = fma(s, b.y, c.y); |
|
|
|
|
|
|
|
return fd; |
|
|
|
return fd; |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma, |
|
|
|
dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, |
|
|
|
DEVICE, STMTS_WRAPPER({ |
|
|
|
DEVICE, STMTS_WRAPPER({ |
|
|
|
dtype::float8_ fd; |
|
|
|
dtype::float8 fd; |
|
|
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, |
|
|
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, |
|
|
|
TernaryOpType::kFma> |
|
|
|
TernaryOpType::kFma> |
|
|
|
fma; |
|
|
|
fma; |
|
|
@ -197,9 +190,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
return fd; |
|
|
|
return fd; |
|
|
|
})) |
|
|
|
})) |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( |
|
|
|
__nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma, |
|
|
|
__nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE, |
|
|
|
DEVICE, STMTS_WRAPPER({ |
|
|
|
STMTS_WRAPPER({ |
|
|
|
dtype::float8_ fd; |
|
|
|
dtype::float8 fd; |
|
|
|
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; |
|
|
|
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; |
|
|
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, |
|
|
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, |
|
|
|
TernaryOpType::kFma> |
|
|
|
TernaryOpType::kFma> |
|
|
|