mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>pull/5991/head
parent
ad3fa4f49c
commit
b4d2377d4c
|
@ -42,7 +42,7 @@ try:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
|
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
|
||||||
|
|
||||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||||
1024,
|
1024,
|
||||||
|
@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm):
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: FusedRMSNorm module.
|
nn.Module: FusedRMSNorm module.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
pass
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
|
|
||||||
)
|
|
||||||
|
|
||||||
LazyInitContext.materialize(module)
|
LazyInitContext.materialize(module)
|
||||||
|
|
||||||
|
@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm):
|
||||||
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
|
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
|
||||||
elementwise_affine = getattr(module, "elementwise_affine", True)
|
elementwise_affine = getattr(module, "elementwise_affine", True)
|
||||||
|
|
||||||
|
try:
|
||||||
rmsnorm = FusedRMSNormWithHook(
|
rmsnorm = FusedRMSNormWithHook(
|
||||||
normalized_shape=normalized_shape,
|
normalized_shape=normalized_shape,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
elementwise_affine=elementwise_affine,
|
elementwise_affine=elementwise_affine,
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn(
|
||||||
|
"Module replacement failed.\
|
||||||
|
Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
|
||||||
|
)
|
||||||
|
return module
|
||||||
|
|
||||||
rmsnorm.weight = module.weight
|
rmsnorm.weight = module.weight
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue