[hotfix] Fixed fused layernorm bug when apex is not installed (#5609)

* fixed fused layernorm bug when apex is not installed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* same for flash attn

* remove flash attn check

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5627/head^2
Edenzzzz 7 months ago committed by GitHub
parent 0d0a582033
commit 7ee569b05f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
# fall back to the normal fused layernorm is not built # fall back to the normal fused layernorm is not built
ApexFusedLayerNorm = FusedLayerNormWithHook ApexFusedLayerNorm = FusedLayerNormWithHook
else: else:
ApexFusedLayerNorm = FusedLayerNormWithHook try:
ApexFusedLayerNorm = FusedLayerNormWithHook
except NameError:
warnings.warn(
"Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
)
return module
layernorm = ( layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)

@ -120,7 +120,15 @@ class ShardConfig:
Turn on all optimization. Turn on all optimization.
""" """
# you can add all the optimization flag here # you can add all the optimization flag here
self.enable_fused_normalization = True try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm # noqa
apex_avail = True
except ImportError:
apex_avail = False
warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
self.enable_fused_normalization = apex_avail
self.enable_flash_attention = True self.enable_flash_attention = True
self.enable_jit_fused = True self.enable_jit_fused = True
# This can cause non-in-place param sharding when used without ZeRO. # This can cause non-in-place param sharding when used without ZeRO.

Loading…
Cancel
Save