ColossalAI/colossalai/fx/tracer/meta_patch/patched_module/normalization.py

33 lines
1.1 KiB
Python
Raw Normal View History

import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    """Meta-patched forward shared by the registered normalization modules.

    Args:
        self: the normalization module instance being traced.
        input: the tensor passed to the module's forward.

    Returns:
        ``input.clone()``: a new tensor with the same shape as ``input``,
        because these normalization layers do not change their input's shape.

    Raises:
        AssertionError: if ``self`` is a ``BatchNorm1d``/``BatchNorm2d``/
            ``BatchNorm3d`` and ``input`` does not have the rank that layer
            accepts (2 or 3, 4, and 5 respectively).
    """
    # Input ranks accepted by the BatchNorm variants. The other registered
    # modules (LayerNorm, GroupNorm) are not rank-checked here.
    if isinstance(self, torch.nn.BatchNorm1d):
        allowed_dims = (2, 3)
    elif isinstance(self, torch.nn.BatchNorm2d):
        allowed_dims = (4,)
    elif isinstance(self, torch.nn.BatchNorm3d):
        allowed_dims = (5,)
    else:
        allowed_dims = ()

    # Raise explicitly instead of using `assert`, which `python -O` strips.
    # AssertionError is kept so any existing handlers still match.
    if allowed_dims and input.dim() not in allowed_dims:
        expected = " or ".join(f"{dim}D" for dim in allowed_dims)
        raise AssertionError(
            f"{type(self).__name__} expected {expected} input (got {input.dim()}D input)"
        )

    # Normalization maintains the input's shape, so a clone describes the output.
    return input.clone()
# apex is an optional dependency. Register each of its fused normalization
# modules independently, so that an apex release lacking some of these classes
# still gets the ones it does provide registered. A single shared `try` would
# skip every class after the first missing one.
try:
    import apex
except (ImportError, AttributeError):
    pass
else:
    # getattr with a default tolerates an `apex` module without `normalization`.
    _apex_normalization = getattr(apex, "normalization", None)
    for _norm_cls_name in (
        "FusedLayerNorm",
        "FusedRMSNorm",
        "MixedFusedLayerNorm",
        "MixedFusedRMSNorm",
    ):
        _norm_cls = getattr(_apex_normalization, _norm_cls_name, None)
        if _norm_cls is not None:
            meta_patched_module.register(_norm_cls)(torch_nn_normalize)