Browse Source

[fx] fixed apex normalization patch exception (#1352)

pull/1353/head
Frank Lee 2 years ago committed by GitHub
parent
commit
274c1a3b5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      colossalai/fx/tracer/meta_patch/patched_module/normalization.py

2
colossalai/fx/tracer/meta_patch/patched_module/normalization.py

@ -26,5 +26,5 @@ try:
meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except ImportError:
except (ImportError, AttributeError):
pass

Loading…
Cancel
Save