diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 1c2b44fc8..2edca3b3e 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -76,18 +76,24 @@ if SUPPORT_NPU:
     FusedRMSNormWithHook = NPUFusedRMSNormWithHook
 else:
-    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+    try:
+        from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
 
-    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
-        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-            super().__init__(normalized_shape, eps, elementwise_affine)
+        class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
+            def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+                super().__init__(normalized_shape, eps, elementwise_affine)
 
-        def forward(self, input):
-            output = super().forward(input)
-            output = hook_parameter_in_backward(output, self.weight)
-            return output
+            def forward(self, input):
+                output = super().forward(input)
+                output = hook_parameter_in_backward(output, self.weight)
+                return output
+
+        FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
+    except ImportError:
+        warnings.warn(
+            "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
+        )
 
-    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
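Below is a minimal, self-contained sketch of the optional-import pattern this patch applies: attempt to import apex's fused RMSNorm and, if apex is not installed, warn instead of failing at import time. Unlike the patch, which only emits the warning, the sketch binds a hypothetical pure-PyTorch fallback (`NaiveRMSNorm`) so it remains runnable without apex; that fallback class and the `RMSNormImpl` alias are illustrative assumptions, not part of ColossalAI.

```python
# Sketch of the import-guard pattern: prefer apex's fused RMSNorm kernel when
# available, otherwise warn and fall back to a naive implementation.
# NOTE: NaiveRMSNorm / RMSNormImpl are hypothetical names used only for this example.
import warnings

import torch
import torch.nn as nn


class NaiveRMSNorm(nn.Module):
    """Pure-PyTorch RMSNorm used only when apex is unavailable (illustrative)."""

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape)) if elementwise_affine else None

    def forward(self, x):
        # RMSNorm: scale by the reciprocal root-mean-square over the last dimension.
        norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return norm * self.weight if self.weight is not None else norm


try:
    # Fast path: apex provides a fused CUDA kernel with the same constructor signature.
    from apex.normalization import FusedRMSNorm as RMSNormImpl
except ImportError:
    warnings.warn(
        "Please install apex from source (https://github.com/NVIDIA/apex) "
        "to use the fused RMSNorm kernel; falling back to a naive implementation."
    )
    RMSNormImpl = NaiveRMSNorm


if __name__ == "__main__":
    layer = RMSNormImpl(8)
    out = layer(torch.randn(2, 8))
    print(out.shape)  # torch.Size([2, 8])
```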