mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Fixed fused layernorm bug without apex (#5609)
* fixed fused layernorm bug without apex
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* same for flash attn
* remove flash attn check

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0d0a582033
commit 7ee569b05f
@@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
                 # fall back to the normal fused layernorm is not built
                 ApexFusedLayerNorm = FusedLayerNormWithHook
         else:
-            ApexFusedLayerNorm = FusedLayerNormWithHook
+            try:
+                ApexFusedLayerNorm = FusedLayerNormWithHook
+            except NameError:
+                warnings.warn(
+                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                )
+                return module
 
         layernorm = (
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
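Read in isolation, the guarded fallback this hunk adds amounts to the pattern sketched below: FusedLayerNormWithHook only exists if the Apex import at module level succeeded, so referencing it on an Apex-less machine raises NameError, and the converter warns and hands back the untouched PyTorch module instead of crashing. The module layout, the convert_layernorm name, and the minimal hook subclass are illustrative assumptions; only the try/except-NameError fallback itself comes from the diff.

import warnings

import torch.nn as nn

# The Apex import may fail on machines where Apex was not built from source;
# in that case FusedLayerNormWithHook is simply never defined at module level.
try:
    from apex.normalization import FusedLayerNorm as _ApexFusedLayerNorm

    class FusedLayerNormWithHook(_ApexFusedLayerNorm):
        # Hook/gradient-aggregation logic is omitted in this sketch.
        pass

except ImportError:
    pass


def convert_layernorm(module: nn.LayerNorm) -> nn.Module:
    """Return an Apex fused layernorm when available, else the original module."""
    try:
        # NameError here means the Apex import above failed.
        ApexFusedLayerNorm = FusedLayerNormWithHook
    except NameError:
        warnings.warn(
            "Please install Apex from source to use fused kernels, or set "
            "enable_fused_normalization = False. Using vanilla layernorm instead."
        )
        return module

    # Assumes elementwise_affine=True on the input module, as the original code path does.
    return (
        ApexFusedLayerNorm(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine)
        .to(module.weight.dtype)
        .to(module.weight.device)
    )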
@@ -120,7 +120,15 @@ class ShardConfig:
         Turn on all optimization.
         """
         # you can add all the optimization flag here
-        self.enable_fused_normalization = True
+        try:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm  # noqa
+
+            apex_avail = True
+        except ImportError:
+            apex_avail = False
+            warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
+
+        self.enable_fused_normalization = apex_avail
         self.enable_flash_attention = True
         self.enable_jit_fused = True
         # This can cause non-in-place param sharding when used without ZeRO.
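The second hunk applies the same idea one level earlier, at configuration time: probe the Apex import once, warn if it is missing, and only then decide the value of enable_fused_normalization. A minimal standalone sketch of that probe, using a hypothetical OptimizationConfig stand-in rather than ColossalAI's actual ShardConfig, could look like this:

import warnings
from dataclasses import dataclass


@dataclass
class OptimizationConfig:
    # Hypothetical stand-in for ShardConfig; the field names mirror the diff.
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False

    def turn_on_all_optimization(self) -> None:
        """Enable every optimization, but only claim fused layernorm if Apex imports."""
        try:
            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm  # noqa: F401

            apex_avail = True
        except ImportError:
            apex_avail = False
            warnings.warn("You set enable_all_optimization=True, but apex is not installed.")

        self.enable_fused_normalization = apex_avail
        self.enable_flash_attention = True
        self.enable_jit_fused = True


config = OptimizationConfig()
config.turn_on_all_optimization()
print(config.enable_fused_normalization)  # False on a machine without Apex

The effect of the change is that enable_all_optimization no longer promises a fused layernorm it cannot deliver: on an Apex-less machine the flag degrades to False with a warning, while flash attention and JIT fusion stay enabled.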