[Hotfix] hotfix normalization (#6163)

* [fix] hotfix normalization

* [hotfix] force doc ci test

* [hotfix] fallback doc
pull/6165/head
duanjunwen 2024-12-23 16:29:48 +08:00 committed by GitHub
parent 130229fdcb
commit fa9d0318e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 15 additions and 9 deletions

View File

@@ -76,18 +76,24 @@ if SUPPORT_NPU:
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
else:
try:
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
        """Apex FusedRMSNorm whose output re-attaches the weight to the
        autograd graph via ``hook_parameter_in_backward`` so gradient
        bookkeeping sees the parameter in backward.
        """

        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            output = super().forward(input)
            # Hook the weight into the backward pass of this output.
            output = hook_parameter_in_backward(output, self.weight)
            return output

    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
except ImportError:
    warnings.warn(
        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
    )

    class _MissingApexFusedRMSNorm:
        # Fallback stub: importing this module still succeeds without apex
        # (e.g. for doc builds), but instantiating the layer fails with a
        # clear RuntimeError instead of the NameError that an unbound
        # FusedRMSNormWithHook would previously have raised.
        def __init__(self, *args, **kwargs):
            raise RuntimeError(
                "FusedRMSNormWithHook requires apex; install it from source (https://github.com/NVIDIA/apex)"
            )

    FusedRMSNormWithHook = _MissingApexFusedRMSNorm
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,