[Inference] Fix quant bits order (#5681)

pull/5679/head
傅剑寒 2024-04-30 19:38:00 +08:00 committed by GitHub
parent f79963199c
commit 9df016fc45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -390,7 +390,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x)); static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
uint16_t tmp2 = uint16_t tmp2 =
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y)); static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
uint16_t res = (tmp1 << 8U) | tmp2; uint16_t res = (tmp2 << 8U) | tmp1;
return res; return res;
})) }))
@ -401,8 +401,8 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
b = CastFunctor<float, uint8_t>()(val.y); b = CastFunctor<float, uint8_t>()(val.y);
c = CastFunctor<float, uint8_t>()(val.z); c = CastFunctor<float, uint8_t>()(val.z);
d = CastFunctor<float, uint8_t>()(val.w); d = CastFunctor<float, uint8_t>()(val.w);
return (a << 24U) | (b << 16U) | return (d << 24U) | (c << 16U) |
(c << 8U) | d; (b << 8U) | a;
})) }))
// fp8x4 -> float4_ // fp8x4 -> float4_
@ -458,7 +458,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x)); static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
uint16_t b = uint16_t b =
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y)); static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
return (a << 8U) | b; return (b << 8U) | a;
})) }))
// bf164 -> fp8x4 // bf164 -> fp8x4