From fa9d0318e4ba510e6abc022662b461c8334dcee2 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Mon, 23 Dec 2024 16:29:48 +0800
Subject: [PATCH] [Hotfix] hotfix normalization (#6163)

* [fix] hotfix normalization

* [hotfix] force doc ci test

* [hotfix] fallback doc
---
 colossalai/shardformer/layer/normalization.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 1c2b44fc8..2edca3b3e 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -76,18 +76,24 @@
 if SUPPORT_NPU:
     FusedRMSNormWithHook = NPUFusedRMSNormWithHook
 else:
-    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+    try:
+        from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
 
-    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
-        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-            super().__init__(normalized_shape, eps, elementwise_affine)
+        class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
+            def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+                super().__init__(normalized_shape, eps, elementwise_affine)
 
-        def forward(self, input):
-            output = super().forward(input)
-            output = hook_parameter_in_backward(output, self.weight)
-            return output
+            def forward(self, input):
+                output = super().forward(input)
+                output = hook_parameter_in_backward(output, self.weight)
+                return output
+
+        FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
+    except ImportError:
+        warnings.warn(
+            "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
+        )
 
-    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
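
The hotfix replaces a hard apex dependency with a guarded import: when apex is missing, the module now warns instead of raising `ImportError` at import time. Below is a minimal, self-contained sketch of the same guarded-import pattern, for context only. `HAS_APEX`, `SimpleRMSNorm`, and `RMSNormImpl` are hypothetical names, and the pure-PyTorch fallback class is an illustrative assumption: the patch itself only warns and defines no fallback.

```python
# Sketch of a guarded optional-dependency import with a fallback.
# HAS_APEX, SimpleRMSNorm, and RMSNormImpl are hypothetical names;
# only the apex import path matches the patched module.
import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

    HAS_APEX = True
except ImportError:
    HAS_APEX = False
    warnings.warn(
        "Please install apex from source (https://github.com/NVIDIA/apex) "
        "to use the fused RMSNorm kernel"
    )


class SimpleRMSNorm(nn.Module):
    """Pure-PyTorch RMSNorm fallback (illustrative; not part of the patch).

    For simplicity this normalizes over the last dimension only.
    """

    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape)) if elementwise_affine else None

    def forward(self, input):
        # RMSNorm: x * rsqrt(mean(x^2) + eps), optionally scaled by a learned weight.
        output = input * torch.rsqrt(input.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return output * self.weight if self.weight is not None else output


# Bind a usable implementation either way, preferring the fused kernel.
RMSNormImpl = ApexFusedRMSNorm if HAS_APEX else SimpleRMSNorm
```

With this shape, callers can instantiate `RMSNormImpl(hidden_size)` whether or not apex is installed. Note that the patch as merged takes a narrower approach: it leaves `FusedRMSNormWithHook` undefined on the apex-less path, so downstream code must not reach the CUDA fused branch in that environment.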