diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 43dd153af..bba4bd070 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
                 # fall back to the normal fused layernorm is not built
                 ApexFusedLayerNorm = FusedLayerNormWithHook
         else:
-            ApexFusedLayerNorm = FusedLayerNormWithHook
+            try:
+                ApexFusedLayerNorm = FusedLayerNormWithHook
+            except NameError:
+                warnings.warn(
+                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                )
+                return module
 
         layernorm = (
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 597dd9c26..e20b8e239 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -120,7 +120,15 @@ class ShardConfig:
         Turn on all optimization.
         """
         # you can add all the optimization flag here
-        self.enable_fused_normalization = True
+        try:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm  # noqa
+
+            apex_avail = True
+        except ImportError:
+            apex_avail = False
+            warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
+
+        self.enable_fused_normalization = apex_avail
         self.enable_flash_attention = True
         self.enable_jit_fused = True
         # This can cause non-in-place param sharding when used without ZeRO.
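
Note (not part of the patch): both hunks rely on the same apex-availability pattern, i.e. probe for apex's fused layer norm and fall back to the vanilla implementation with a warning when it is missing. A minimal standalone sketch of that pattern follows; the apex.normalization.FusedLayerNorm import path matches the one added to ShardConfig above, while the build_layernorm helper name is hypothetical.

import warnings

import torch.nn as nn

try:
    # Same guarded import the ShardConfig hunk adds: only enable the fused
    # path when apex's fused layer norm can actually be imported.
    from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm  # noqa

    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


def build_layernorm(normalized_shape, eps=1e-5):
    # Hypothetical helper: prefer apex's fused layer norm when installed,
    # otherwise warn and fall back to torch's vanilla nn.LayerNorm,
    # mirroring the fallback added in normalization.py above.
    if APEX_AVAILABLE:
        return ApexFusedLayerNorm(normalized_shape, eps=eps)
    warnings.warn("apex is not installed; using vanilla nn.LayerNorm instead of the fused kernel.")
    return nn.LayerNorm(normalized_shape, eps=eps)


if __name__ == "__main__":
    # Prints FusedLayerNorm when apex is available, LayerNorm otherwise.
    print(type(build_layernorm(1024)).__name__)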