mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Fixed fused layernorm bug without apex (#5609)
* fixed fused layernorm bug without apex
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* same for flash attn
* remove flash attn check
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0d0a582033
commit 7ee569b05f
@@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
                 # fall back to the normal fused layernorm is not built
                 ApexFusedLayerNorm = FusedLayerNormWithHook
         else:
-            ApexFusedLayerNorm = FusedLayerNormWithHook
+            try:
+                ApexFusedLayerNorm = FusedLayerNormWithHook
+            except NameError:
+                warnings.warn(
+                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                )
+                return module

         layernorm = (
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
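Why the NameError catch works: FusedLayerNormWithHook is only defined when the apex import at module load time succeeds, so merely referencing the name without apex raises NameError; the hotfix catches that, warns, and hands back the unmodified nn.LayerNorm. Below is a minimal standalone sketch of that pattern — the flat module layout, the empty hook class body, and the simplified from_native_module signature are assumptions for illustration, not ColossalAI's actual code.

import warnings

import torch.nn as nn

try:
    from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm

    class FusedLayerNormWithHook(ApexFusedLayerNorm):
        # the real class additionally registers gradient hooks; omitted here
        pass

except ImportError:
    # apex missing: FusedLayerNormWithHook is never defined, so any later
    # reference to the name raises NameError instead of ImportError
    pass


def from_native_module(module: nn.LayerNorm) -> nn.Module:
    # simplified stand-in for FusedLayerNorm.from_native_module
    try:
        layernorm_cls = FusedLayerNormWithHook  # NameError here when apex is absent
    except NameError:
        warnings.warn("Apex is not installed; using vanilla layernorm instead.")
        return module
    return (
        layernorm_cls(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine)
        .to(module.weight.dtype)
        .to(module.weight.device)
    )


layer = from_native_module(nn.LayerNorm(1024))  # vanilla LayerNorm without apex, fused with it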
@@ -120,7 +120,15 @@ class ShardConfig:
         Turn on all optimization.
         """
         # you can add all the optimization flag here
-        self.enable_fused_normalization = True
+        try:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm  # noqa
+
+            apex_avail = True
+        except ImportError:
+            apex_avail = False
+            warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
+
+        self.enable_fused_normalization = apex_avail
         self.enable_flash_attention = True
         self.enable_jit_fused = True
         # This can cause non-in-place param sharding when used without ZeRO.
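The ShardConfig side makes the complementary check at configuration time: enable_all_optimization now probes for apex and only flips enable_fused_normalization when the import succeeds, instead of setting it unconditionally and failing later. A minimal standalone sketch of that probe, assuming a trimmed-down dataclass — the TinyShardConfig name and its __post_init__ wiring are hypothetical stand-ins; only the flag names, the warning text, and the try/except probe come from the diff above.

import warnings
from dataclasses import dataclass


@dataclass
class TinyShardConfig:  # hypothetical stand-in for the real ShardConfig
    enable_all_optimization: bool = False
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False

    def __post_init__(self):
        if self.enable_all_optimization:
            self._turn_on_all_optimization()

    def _turn_on_all_optimization(self):
        # Probe for apex: the fused-normalization flag is only honoured when
        # the import actually succeeds.
        try:
            from apex.normalization import FusedLayerNorm  # noqa: F401

            apex_avail = True
        except ImportError:
            apex_avail = False
            warnings.warn("You set enable_all_optimization=True, but apex is not installed.")

        self.enable_fused_normalization = apex_avail
        self.enable_flash_attention = True
        self.enable_jit_fused = True


config = TinyShardConfig(enable_all_optimization=True)
print(config.enable_fused_normalization)  # False unless apex is importable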