2022-07-12 07:01:58 +00:00
|
|
|
import torch
|
2022-11-01 14:53:51 +00:00
|
|
|
|
|
|
|
from ...registry import meta_patched_module
|
2022-07-12 07:01:58 +00:00
|
|
|
|
|
|
|
|
|
|
|
@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    """Meta-patched forward shared by the registered normalization modules.

    For BatchNorm1d/2d/3d the rank of ``input`` is validated; the other
    registered modules (LayerNorm, GroupNorm) are not checked. The result is
    always ``input.clone()``, i.e. a new tensor with the same shape as the
    input.

    Args:
        self: the normalization module whose forward is being patched.
        input: the tensor passed to the module's forward.

    Returns:
        ``input.clone()``.

    Raises:
        AssertionError: if a BatchNorm1d/2d/3d module receives an input whose
            number of dimensions is not accepted by that module.
    """
    # Accepted input ranks per BatchNorm variant; ``None`` means "not checked".
    if isinstance(self, torch.nn.BatchNorm1d):
        expected_dims = (2, 3)
    elif isinstance(self, torch.nn.BatchNorm2d):
        expected_dims = (4,)
    elif isinstance(self, torch.nn.BatchNorm3d):
        expected_dims = (5,)
    else:
        expected_dims = None

    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O`; AssertionError is kept so existing callers see the same type.
    if expected_dims is not None and input.dim() not in expected_dims:
        raise AssertionError(
            f"{type(self).__name__} expects an input with dim in {expected_dims}, "
            f"but got a {input.dim()}D input of shape {tuple(input.shape)}"
        )

    # normalization keeps the same shape as the input
    return input.clone()
|
|
|
|
|
|
|
|
|
|
|
|
# Names of apex fused normalization classes that share torch_nn_normalize.
# Not every apex release ships all of them, so each is looked up independently.
_APEX_NORM_CLASS_NAMES = (
    "FusedLayerNorm",
    "FusedRMSNorm",
    "MixedFusedLayerNorm",
    "MixedFusedRMSNorm",
)


def _register_apex_normalization(apex_module):
    """Register ``torch_nn_normalize`` for every apex normalization class present.

    Each class is resolved with ``getattr(..., None)``. A class missing from
    the installed apex version is skipped instead of aborting the remaining
    registrations.

    Args:
        apex_module: the imported ``apex`` package.
    """
    normalization = getattr(apex_module, "normalization", None)
    for cls_name in _APEX_NORM_CLASS_NAMES:
        # getattr(None, name, None) is None, so a missing submodule is handled too.
        norm_cls = getattr(normalization, cls_name, None)
        if norm_cls is not None:
            meta_patched_module.register(norm_cls)(torch_nn_normalize)


try:
    import apex
except (ImportError, AttributeError):
    # apex is an optional dependency (and may fail to import against an
    # incompatible torch); without it there is nothing to register.
    pass
else:
    _register_apex_normalization(apex)
|