[Inference] Fix quant bits order (#5681)

pull/5679/head
傅剑寒 2024-04-30 19:38:00 +08:00 committed by GitHub
parent f79963199c
commit 9df016fc45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 4 additions and 4 deletions

View File

@@ -390,7 +390,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
uint16_t tmp2 =
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
-uint16_t res = (tmp1 << 8U) | tmp2;
+uint16_t res = (tmp2 << 8U) | tmp1;
return res;
}))
@@ -401,8 +401,8 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
b = CastFunctor<float, uint8_t>()(val.y);
c = CastFunctor<float, uint8_t>()(val.z);
d = CastFunctor<float, uint8_t>()(val.w);
-return (a << 24U) | (b << 16U) |
-(c << 8U) | d;
+return (d << 24U) | (c << 16U) |
+(b << 8U) | a;
}))
// fp8x4 -> float4_
@@ -458,7 +458,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
uint16_t b =
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
-return (a << 8U) | b;
+return (b << 8U) | a;
}))
// bf164 -> fp8x4