mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] hotfix normalization (#6163)
* [fix] hotfix normalization * [hotfix] force doc ci test * [hotfix] fallback doc
parent
130229fdcb
commit
fa9d0318e4
|
@ -76,18 +76,24 @@ if SUPPORT_NPU:
|
||||||
|
|
||||||
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
||||||
else:
|
else:
|
||||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
try:
|
||||||
|
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||||
|
|
||||||
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
    """Apex fused RMSNorm with a parameter hook attached to the output.

    ``forward`` passes the fused kernel's result through
    ``hook_parameter_in_backward`` together with ``self.weight`` —
    presumably so that backward hooks on the weight parameter still fire
    despite the fused kernel (NOTE(review): confirm against the
    ``hook_parameter_in_backward`` implementation elsewhere in this file).
    """

    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        # Delegate construction entirely to the apex implementation.
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        # Fused kernel first, then tie the weight into the output's backward.
        normalized = super().forward(input)
        return hook_parameter_in_backward(normalized, self.weight)


FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn(
|
||||||
|
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
|
||||||
|
)
|
||||||
|
|
||||||
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
|
||||||
|
|
||||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||||
1024,
|
1024,
|
||||||
|
|
Loading…
Reference in New Issue