From b4d2377d4c482960af21bb77bf5ff78099865b02 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 9 Aug 2024 18:17:09 +0800 Subject: [PATCH] [Hotfix] Avoid fused RMSnorm import error without apex (#5985) Co-authored-by: Edenzzzz --- colossalai/shardformer/layer/normalization.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 59e1da9fc..043bf6aeb 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -42,7 +42,7 @@ try: return output except ImportError: - warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel") + warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, @@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm): Returns: nn.Module: FusedRMSNorm module. """ - try: - pass - except ImportError: - raise ImportError( - "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" - ) LazyInitContext.materialize(module) @@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm): eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps elementwise_affine = getattr(module, "elementwise_affine", True) - rmsnorm = FusedRMSNormWithHook( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - ) + try: + rmsnorm = FusedRMSNormWithHook( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + ) + except ImportError: + warnings.warn( + "Module replacement failed.\ + Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" + ) + return module rmsnorm.weight = module.weight