[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)

pull/5660/merge
Steve Luo 2024-05-06 10:55:34 +08:00 committed by GitHub
parent 537a3cbc4d
commit 725fbd2ed0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 112 additions and 147 deletions

View File

@ -40,14 +40,7 @@ struct half8 {
#endif
};
struct float4_ {
#ifdef COLOSSAL_WITH_CUDA
float2 x;
float2 y;
#endif
};
struct float8_ {
struct float8 {
#ifdef COLOSSAL_WITH_CUDA
float2 x;
float2 y;

View File

@ -49,7 +49,7 @@ VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8)
#endif /* defined(COLOSSAL_WITH_CUDA) */
#undef VEC_TYPE_TRAITS_SPECIALIZATION
@ -64,11 +64,11 @@ VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2)
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4)
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, dtype::float4_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, float4);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, dtype::float4_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8_);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, float4);
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8);
#endif /* COLOSSAL_WITH_CUDA */
#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION

View File

@ -164,22 +164,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
return mul(fa, fb);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul,
DEVICE, STMTS_WRAPPER({
dtype::float4_ fc;
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
BinaryOpType::kMul>
mul;
fc.x = mul(lhs.x, rhs.x);
fc.y = mul(lhs.y, rhs.y);
return fc;
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::bfloat164, dtype::bfloat164,
float4, BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
float4 fc;
CastFunctor<__nv_bfloat16, float> cast;
fc.x = cast(lhs.x.x) * cast(rhs.x.x);
fc.y = cast(lhs.x.y) * cast(rhs.x.y);
fc.z = cast(lhs.y.x) * cast(rhs.y.x);
fc.w = cast(lhs.y.y) * cast(rhs.y.y);
return fc;
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul,
dtype::bfloat168, dtype::bfloat168, dtype::float8, BinaryOpType::kMul,
DEVICE, STMTS_WRAPPER({
dtype::float8_ fc;
dtype::float8 fc;
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
BinaryOpType::kMul>
mul;
@ -199,20 +199,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
return mul(fa, fb);
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
dtype::float4_ fc;
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
fc.x = mul(lhs.x, rhs.x);
fc.y = mul(lhs.y, rhs.y);
return fc;
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::half4, dtype::half4, float4,
BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
float4 fc;
CastFunctor<half, float> cast;
fc.x = cast(lhs.x.x) * cast(rhs.x.x);
fc.y = cast(lhs.x.y) * cast(rhs.x.y);
fc.z = cast(lhs.y.x) * cast(rhs.y.x);
fc.w = cast(lhs.y.y) * cast(rhs.y.y);
return fc;
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE,
dtype::half8, dtype::half8, dtype::float8, BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
dtype::float8_ fc;
dtype::float8 fc;
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
fc.x = mul(lhs.x, rhs.x);
fc.y = mul(lhs.y, rhs.y);

View File

@ -69,14 +69,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
dst.y = __floats2half2_rn(val.z, val.w);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::half4, float4, DEVICE,
STMTS_WRAPPER({
dtype::half4 dst;
dst.x = __float22half2_rn(val.x);
dst.y = __float22half2_rn(val.y);
float4 dst;
dst.x = __half2float(val.x.x);
dst.y = __half2float(val.x.y);
dst.z = __half2float(val.y.x);
dst.w = __half2float(val.y.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::half8, DEVICE,
STMTS_WRAPPER({
dtype::half8 dst;
dst.x = __float22half2_rn(val.x);
@ -107,6 +109,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
__floats2bfloat162_rn(val.z, val.w);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::bfloat164, float4, DEVICE,
STMTS_WRAPPER({
float4 dst;
dst.x = __bfloat162float(val.x.x);
dst.y = __bfloat162float(val.x.y);
dst.z = __bfloat162float(val.y.x);
dst.w = __bfloat162float(val.y.y);
return dst;
}))
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
@ -120,14 +131,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
return __float22bfloat162_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x = __float22bfloat162_rn(val.x);
dst.y = __float22bfloat162_rn(val.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::bfloat168, DEVICE,
STMTS_WRAPPER({
dtype::bfloat168 dst;
dst.x = __float22bfloat162_rn(val.x);
@ -155,14 +159,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
val.y);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
dtype::bfloat168 dst;
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
@ -405,35 +402,27 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
(b << 8U) | a;
}))
// fp8x4 -> float4_
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
dtype::float4_ res;
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
res.y =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
return res;
}))
// fp8x4 -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, float4, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
float4 res;
res.x = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val));
res.y = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 8U));
res.z = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 16U));
res.w = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 24U));
return res;
}))
// fp8x8 -> float8_
// fp8x8 -> float8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp1, tmp2;
tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
dtype::float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
uint2, dtype::float8, DEVICE, STMTS_WRAPPER({
dtype::float8 res;
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x));
res.y =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x >> 16U));
res.z = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y));
res.w =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y >> 16U));
return res;
}))
@ -482,34 +471,22 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
return uint32;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
STMTS_WRAPPER({
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint2, DEVICE, STMTS_WRAPPER({
uint2 b;
float2 c;
c.x = val.x.x;
c.y = val.x.y;
c.x = val.x;
c.y = val.y;
b.x = CastFunctor<float2, uint32_t>()(c);
c.x = val.y.x;
c.y = val.y.y;
c.x = val.z;
c.y = val.w;
b.y = CastFunctor<float2, uint32_t>()(c);
return b;
}))
// float4_ -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
STMTS_WRAPPER({
float4 b;
b.x = val.x.x;
b.y = val.x.y;
b.z = val.y.x;
b.w = val.y.y;
return b;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
dtype::float8, uint4, DEVICE, STMTS_WRAPPER({
uint4 b;
b.x = CastFunctor<float2, uint32_t>()(val.x);
b.y = CastFunctor<float2, uint32_t>()(val.y);

View File

@ -94,29 +94,27 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fma(cast(a), b, c);
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float4_ fd;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
fd.x = fma(a.x, b.x, c.x);
fd.y = fma(a.y, b.y, c.y);
float4 fd;
CastFunctor<dtype::half4, float4> cast;
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast(a), cast(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float4_ fd;
CastFunctor<half, half2> cast;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
half2 s = cast(a);
fd.x = fma(s, b.x, c.x);
fd.y = fma(s, b.y, c.y);
half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
float4 fd;
CastFunctor<half, float> cast0;
CastFunctor<dtype::half4, float4> cast1;
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast0(a), cast1(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float8_ fd;
dtype::float8 fd;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
fd.x = fma(a.x, b.x, c.x);
fd.y = fma(a.y, b.y, c.y);
@ -125,9 +123,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float8_ fd;
dtype::float8 fd;
CastFunctor<half, half2> cast;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
half2 s = cast(a);
@ -160,33 +158,28 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fma(cast(a), b, c);
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float4_ fd;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>
fma;
fd.x = fma(a.x, b.x, c.x);
fd.y = fma(a.y, b.y, c.y);
dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
float4 fd;
CastFunctor<dtype::bfloat164, float4> cast;
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast(a), cast(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float4_ fd;
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>
fma;
__nv_bfloat162 s = cast(a);
fd.x = fma(s, b.x, c.x);
fd.y = fma(s, b.y, c.y);
__nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
float4 fd;
CastFunctor<__nv_bfloat16, float> cast0;
CastFunctor<dtype::bfloat164, float4> cast1;
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast0(a), cast1(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float8_ fd;
dtype::float8 fd;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>
fma;
@ -197,9 +190,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float8_ fd;
__nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float8 fd;
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>

View File

@ -52,13 +52,7 @@ COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE,
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE,
{ return val.x + val.y + val.z + val.w; })
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float4_, float, UnaryOpType::kSum,
DEVICE, {
return val.x.x + val.x.y + val.y.x +
val.y.y;
})
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8_, float, UnaryOpType::kSum,
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8, float, UnaryOpType::kSum,
DEVICE, {
return val.x.x + val.x.y + val.y.x +
val.y.y + val.z.x + val.z.y +

View File

@ -283,11 +283,14 @@ void rms_layernorm(
case 4:
RMSNORM_LAUNCHER(4, block);
break;
case 5:
RMSNORM_LAUNCHER(5, block);
break;
case 8:
RMSNORM_LAUNCHER(8, block);
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8");
}
}
}
@ -330,11 +333,14 @@ void fused_add_rms_layernorm(
case 4:
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
break;
case 5:
FUSED_ADD_RMSNORM_LAUNCHER(5, block);
break;
case 8:
FUSED_ADD_RMSNORM_LAUNCHER(8, block);
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8");
}
}
}