mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Fix quant bits order (#5681)
parent
f79963199c
commit
9df016fc45
|
@ -390,7 +390,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
|
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
|
||||||
uint16_t tmp2 =
|
uint16_t tmp2 =
|
||||||
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
|
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
|
||||||
uint16_t res = (tmp1 << 8U) | tmp2;
|
uint16_t res = (tmp2 << 8U) | tmp1;
|
||||||
return res;
|
return res;
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -401,8 +401,8 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||||
b = CastFunctor<float, uint8_t>()(val.y);
|
b = CastFunctor<float, uint8_t>()(val.y);
|
||||||
c = CastFunctor<float, uint8_t>()(val.z);
|
c = CastFunctor<float, uint8_t>()(val.z);
|
||||||
d = CastFunctor<float, uint8_t>()(val.w);
|
d = CastFunctor<float, uint8_t>()(val.w);
|
||||||
return (a << 24U) | (b << 16U) |
|
return (d << 24U) | (c << 16U) |
|
||||||
(c << 8U) | d;
|
(b << 8U) | a;
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// fp8x4 -> float4_
|
// fp8x4 -> float4_
|
||||||
|
@ -458,7 +458,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
|
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
|
||||||
uint16_t b =
|
uint16_t b =
|
||||||
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
|
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
|
||||||
return (a << 8U) | b;
|
return (b << 8U) | a;
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// bf164 -> fp8x4
|
// bf164 -> fp8x4
|
||||||
|
|
Loading…
Reference in New Issue