[Hotfix] hotfix normalization (#6163)

* [fix] hotfix normalization

* [hotfix] force doc ci test

* [hotfix] fallback doc
pull/6165/head
duanjunwen 2024-12-23 16:29:48 +08:00 committed by GitHub
parent 130229fdcb
commit fa9d0318e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 15 additions and 9 deletions

View File

@ -76,6 +76,7 @@ if SUPPORT_NPU:
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
else:
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
@ -88,6 +89,11 @@ else:
return output
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
except ImportError:
warnings.warn(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
)
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,