[Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
pull/5991/head
Edenzzzz 2024-08-09 18:17:09 +08:00 committed by GitHub
parent ad3fa4f49c
commit b4d2377d4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 13 additions and 12 deletions

View File

@ -42,7 +42,7 @@ try:
return output return output
except ImportError: except ImportError:
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel") warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
FAST_LAYERNORM_SUPPORTED_SIZE = [ FAST_LAYERNORM_SUPPORTED_SIZE = [
1024, 1024,
@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm):
Returns: Returns:
nn.Module: FusedRMSNorm module. nn.Module: FusedRMSNorm module.
""" """
try:
pass
except ImportError:
raise ImportError(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
)
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm):
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True) elementwise_affine = getattr(module, "elementwise_affine", True)
try:
rmsnorm = FusedRMSNormWithHook( rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape, normalized_shape=normalized_shape,
eps=eps, eps=eps,
elementwise_affine=elementwise_affine, elementwise_affine=elementwise_affine,
) )
except ImportError:
warnings.warn(
"Module replacement failed.\
Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
)
return module
rmsnorm.weight = module.weight rmsnorm.weight = module.weight