mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)
parent
537a3cbc4d
commit
725fbd2ed0
|
@ -40,14 +40,7 @@ struct half8 {
|
|||
#endif
|
||||
};
|
||||
|
||||
struct float4_ {
|
||||
#ifdef COLOSSAL_WITH_CUDA
|
||||
float2 x;
|
||||
float2 y;
|
||||
#endif
|
||||
};
|
||||
|
||||
struct float8_ {
|
||||
struct float8 {
|
||||
#ifdef COLOSSAL_WITH_CUDA
|
||||
float2 x;
|
||||
float2 y;
|
||||
|
|
|
@ -49,7 +49,7 @@ VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
|
|||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
|
||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8)
|
||||
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||
|
||||
#undef VEC_TYPE_TRAITS_SPECIALIZATION
|
||||
|
@ -64,11 +64,11 @@ VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
|
|||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2)
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4)
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, dtype::float4_);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8_);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, float4);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, dtype::float4_);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8_);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, float4);
|
||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8);
|
||||
#endif /* COLOSSAL_WITH_CUDA */
|
||||
|
||||
#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION
|
||||
|
|
|
@ -164,22 +164,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||
return mul(fa, fb);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ fc;
|
||||
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
BinaryOpType::kMul>
|
||||
mul;
|
||||
fc.x = mul(lhs.x, rhs.x);
|
||||
fc.y = mul(lhs.y, rhs.y);
|
||||
return fc;
|
||||
}))
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::bfloat164, dtype::bfloat164,
|
||||
float4, BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fc;
|
||||
CastFunctor<__nv_bfloat16, float> cast;
|
||||
fc.x = cast(lhs.x.x) * cast(rhs.x.x);
|
||||
fc.y = cast(lhs.x.y) * cast(rhs.x.y);
|
||||
fc.z = cast(lhs.y.x) * cast(rhs.y.x);
|
||||
fc.w = cast(lhs.y.y) * cast(rhs.y.y);
|
||||
return fc;
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul,
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8, BinaryOpType::kMul,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8_ fc;
|
||||
dtype::float8 fc;
|
||||
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
BinaryOpType::kMul>
|
||||
mul;
|
||||
|
@ -199,20 +199,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||
return mul(fa, fb);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float4_ fc;
|
||||
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
||||
fc.x = mul(lhs.x, rhs.x);
|
||||
fc.y = mul(lhs.y, rhs.y);
|
||||
return fc;
|
||||
}))
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::half4, dtype::half4, float4,
|
||||
BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fc;
|
||||
CastFunctor<half, float> cast;
|
||||
fc.x = cast(lhs.x.x) * cast(rhs.x.x);
|
||||
fc.y = cast(lhs.x.y) * cast(rhs.x.y);
|
||||
fc.z = cast(lhs.y.x) * cast(rhs.y.x);
|
||||
fc.w = cast(lhs.y.y) * cast(rhs.y.y);
|
||||
return fc;
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE,
|
||||
dtype::half8, dtype::half8, dtype::float8, BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8_ fc;
|
||||
dtype::float8 fc;
|
||||
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
||||
fc.x = mul(lhs.x, rhs.x);
|
||||
fc.y = mul(lhs.y, rhs.y);
|
||||
|
|
|
@ -69,14 +69,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
|
|||
dst.y = __floats2half2_rn(val.z, val.w);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::half4, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::half4 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
dst.y = __float22half2_rn(val.y);
|
||||
float4 dst;
|
||||
dst.x = __half2float(val.x.x);
|
||||
dst.y = __half2float(val.x.y);
|
||||
dst.z = __half2float(val.y.x);
|
||||
dst.w = __half2float(val.y.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::half8, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::half8 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
|
@ -107,6 +109,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
|
|||
__floats2bfloat162_rn(val.z, val.w);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::bfloat164, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 dst;
|
||||
dst.x = __bfloat162float(val.x.x);
|
||||
dst.y = __bfloat162float(val.x.y);
|
||||
dst.z = __bfloat162float(val.y.x);
|
||||
dst.w = __bfloat162float(val.y.y);
|
||||
return dst;
|
||||
}))
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
|
@ -120,14 +131,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
|
|||
STMTS_WRAPPER({
|
||||
return __float22bfloat162_rn(val);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
dst.y = __float22bfloat162_rn(val.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::bfloat168, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
|
@ -155,14 +159,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
|
|||
val.y);
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||
|
@ -405,35 +402,27 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
|
|||
(b << 8U) | a;
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4_
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ res;
|
||||
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
|
||||
res.y =
|
||||
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, float4, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
float4 res;
|
||||
res.x = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val));
|
||||
res.y = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 8U));
|
||||
res.z = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 16U));
|
||||
res.w = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 24U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x8 -> float8_
|
||||
// fp8x8 -> float8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ tmp1, tmp2;
|
||||
tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
|
||||
tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
|
||||
dtype::float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
uint2, dtype::float8, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8 res;
|
||||
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x));
|
||||
res.y =
|
||||
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x >> 16U));
|
||||
res.z = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y));
|
||||
res.w =
|
||||
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y >> 16U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
|
@ -482,34 +471,22 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
|
|||
return uint32;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint2, DEVICE, STMTS_WRAPPER({
|
||||
uint2 b;
|
||||
float2 c;
|
||||
c.x = val.x.x;
|
||||
c.y = val.x.y;
|
||||
c.x = val.x;
|
||||
c.y = val.y;
|
||||
b.x = CastFunctor<float2, uint32_t>()(c);
|
||||
|
||||
c.x = val.y.x;
|
||||
c.y = val.y.y;
|
||||
c.x = val.z;
|
||||
c.y = val.w;
|
||||
b.y = CastFunctor<float2, uint32_t>()(c);
|
||||
|
||||
return b;
|
||||
}))
|
||||
|
||||
// float4_ -> float4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 b;
|
||||
b.x = val.x.x;
|
||||
b.y = val.x.y;
|
||||
b.z = val.y.x;
|
||||
b.w = val.y.y;
|
||||
return b;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8, uint4, DEVICE, STMTS_WRAPPER({
|
||||
uint4 b;
|
||||
b.x = CastFunctor<float2, uint32_t>()(val.x);
|
||||
b.y = CastFunctor<float2, uint32_t>()(val.y);
|
||||
|
|
|
@ -94,29 +94,27 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||
return fma(cast(a), b, c);
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
|
||||
dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
fd.x = fma(a.x, b.x, c.x);
|
||||
fd.y = fma(a.y, b.y, c.y);
|
||||
float4 fd;
|
||||
CastFunctor<dtype::half4, float4> cast;
|
||||
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast(a), cast(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
CastFunctor<half, half2> cast;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
half2 s = cast(a);
|
||||
fd.x = fma(s, b.x, c.x);
|
||||
fd.y = fma(s, b.y, c.y);
|
||||
half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
||||
float4 fd;
|
||||
CastFunctor<half, float> cast0;
|
||||
CastFunctor<dtype::half4, float4> cast1;
|
||||
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast0(a), cast1(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
|
||||
dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
dtype::float8 fd;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
fd.x = fma(a.x, b.x, c.x);
|
||||
fd.y = fma(a.y, b.y, c.y);
|
||||
|
@ -125,9 +123,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
|
||||
half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
dtype::float8 fd;
|
||||
CastFunctor<half, half2> cast;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
half2 s = cast(a);
|
||||
|
@ -160,33 +158,28 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||
return fma(cast(a), b, c);
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
fma;
|
||||
fd.x = fma(a.x, b.x, c.x);
|
||||
fd.y = fma(a.y, b.y, c.y);
|
||||
dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fd;
|
||||
CastFunctor<dtype::bfloat164, float4> cast;
|
||||
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast(a), cast(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
fma;
|
||||
__nv_bfloat162 s = cast(a);
|
||||
fd.x = fma(s, b.x, c.x);
|
||||
fd.y = fma(s, b.y, c.y);
|
||||
__nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fd;
|
||||
CastFunctor<__nv_bfloat16, float> cast0;
|
||||
CastFunctor<dtype::bfloat164, float4> cast1;
|
||||
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast0(a), cast1(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
dtype::float8 fd;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
fma;
|
||||
|
@ -197,9 +190,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
__nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8 fd;
|
||||
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
|
|
|
@ -52,13 +52,7 @@ COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE,
|
|||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE,
|
||||
{ return val.x + val.y + val.z + val.w; })
|
||||
|
||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float4_, float, UnaryOpType::kSum,
|
||||
DEVICE, {
|
||||
return val.x.x + val.x.y + val.y.x +
|
||||
val.y.y;
|
||||
})
|
||||
|
||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8_, float, UnaryOpType::kSum,
|
||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8, float, UnaryOpType::kSum,
|
||||
DEVICE, {
|
||||
return val.x.x + val.x.y + val.y.x +
|
||||
val.y.y + val.z.x + val.z.y +
|
||||
|
|
|
@ -283,11 +283,14 @@ void rms_layernorm(
|
|||
case 4:
|
||||
RMSNORM_LAUNCHER(4, block);
|
||||
break;
|
||||
case 5:
|
||||
RMSNORM_LAUNCHER(5, block);
|
||||
break;
|
||||
case 8:
|
||||
RMSNORM_LAUNCHER(8, block);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
|
||||
AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -330,11 +333,14 @@ void fused_add_rms_layernorm(
|
|||
case 4:
|
||||
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
|
||||
break;
|
||||
case 5:
|
||||
FUSED_ADD_RMSNORM_LAUNCHER(5, block);
|
||||
break;
|
||||
case 8:
|
||||
FUSED_ADD_RMSNORM_LAUNCHER(8, block);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
|
||||
AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue