diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py
new file mode 100644
index 000000000..991130376
--- /dev/null
+++ b/tests/test_fx/test_meta/test_aten.py
@@ -0,0 +1,88 @@
+from typing import Any, Callable, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from colossalai.fx.profiler import MetaTensor
+
+import pytest
+
+try:
+    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+    INCOMPATIBLE = False    # torch version >= 1.12.0
+except Exception:
+    INCOMPATIBLE = True
+
+aten = torch.ops.aten
+
+registered_meta = {
+    ('aten.convolution.default', True): [    # (aten op, requires_backward)
+        (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
+        (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)),
+        (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)),
+        (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
+        (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
+                            dilation=2), torch.rand(2, 3, 4, 4)),
+        (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
+                            dilation=2), torch.rand(2, 3, 4, 4, 4)),
+    ],
+    ('aten.native_batch_norm.default', True): [
+        (nn.BatchNorm1d(4), torch.rand(2, 4)),
+        (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
+        (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
+    ],
+    ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),],
+    ('aten.avg_pool1d.default', True): [
+        (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
+        (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
+        (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
+        (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
+    ],
+    ('aten.avg_pool2d.default', True): [
+        (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
+        (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
+        (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
+        (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
+    ],
+    ('aten.relu.default', True): [
+        (nn.ReLU(), torch.rand(4, 3, 1, 2)),
+        (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)),
+        (nn.SiLU(), torch.rand(4, 3, 1, 2)),
+        (nn.GELU(), torch.rand(4, 3, 1, 2)),
+        (nn.ELU(), torch.rand(4, 3, 1, 2)),
+        (nn.Sigmoid(), torch.rand(4, 3, 1, 2)),
+        (nn.Tanh(), torch.rand(4, 3, 1, 2)),
+        (nn.Hardswish(), torch.rand(4, 3, 1, 2)),
+    ]
+}
+
+
+def compare_all(tensor: torch.Tensor, meta_tensor: MetaTensor) -> Any:
+    assert tensor.shape == meta_tensor.shape, f'the shapes of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) do not match.'
+    assert tensor.dtype == meta_tensor.dtype, f'the dtypes of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) do not match.'
+    assert tensor.stride() == meta_tensor.stride(), \
+        f'the strides of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) do not match.'
+
+
+def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
+    x.requires_grad = requires_backward
+    meta_x = MetaTensor(x.to('meta'))
+    if isinstance(f, nn.Module):
+        x_out, meta_out = f(x), f.to('meta')(meta_x)
+    else:
+        x_out, meta_out = f(x), f(meta_x)
+    compare_all(x_out, meta_out)
+    if requires_backward:
+        x_out.sum().backward()
+        meta_out.sum().backward()
+        compare_all(x.grad, meta_x.grad)
+
+
+@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0')
+def test_meta_aten():
+    for (aten_op, requires_backward), v in registered_meta.items():
+        for f, x in v:
+            run_and_compare(f, x, requires_backward)
+
+
+if __name__ == '__main__':
+    test_meta_aten()
diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py
new file mode 100644
index 000000000..98b3b464f
--- /dev/null
+++ b/tests/test_fx/test_meta/test_backward.py
@@ -0,0 +1,63 @@
+import torchvision.models as tm
+import timm.models as tmm
+import torch
+from colossalai.fx.profiler import MetaTensor
+
+import pytest
+
+try:
+    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+    incompatible = False    # torch version >= 1.12.0
+except Exception:
+    incompatible = True
+
+
+tm_models = [
+    tm.vgg11,
+    tm.resnet18,
+    tm.densenet121,
+    tm.mobilenet_v3_small,
+    tm.resnext50_32x4d,
+    tm.wide_resnet50_2,
+    tm.regnet_x_16gf,
+    tm.mnasnet0_5,
+    tm.efficientnet_b0,
+]
+
+
+tmm_models = [
+    tmm.resnest.resnest50d,
+    tmm.beit.beit_base_patch16_224,
+    tmm.cait.cait_s24_224,
+    tmm.efficientnet.efficientnetv2_m,
+    tmm.resmlp_12_224,
+    tmm.vision_transformer.vit_base_patch16_224,
+    tmm.deit_base_distilled_patch16_224,
+    tmm.convnext.convnext_base,
+    tmm.vgg.vgg11,
+    tmm.dpn.dpn68,
+    tmm.densenet.densenet121,
+    tmm.rexnet.rexnet_100,
+    tmm.swin_transformer.swin_base_patch4_window7_224,
+]
+
+
+@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0')
+def test_torchvision_models():
+    for m in tm_models:
+        model = m().to('meta')
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        model(MetaTensor(data)).sum().backward()
+
+
+@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0')
+def test_timm_models():
+    for m in tmm_models:
+        model = m().to('meta')
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        model(MetaTensor(data)).sum().backward()
+
+
+if __name__ == '__main__':
+    test_torchvision_models()
+    test_timm_models()
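
Note: both files gate on whether torch.library.Library("aten", "IMPL", "Meta") is constructible, since MetaTensor depends on aten ops having Meta-device kernels. As a minimal sketch of that mechanism (a hypothetical standalone example, not part of this diff; the op choice and the name relu_meta are illustrative, and registration may fail on PyTorch builds where aten::relu already ships a Meta kernel):

    import torch

    # Register a Meta-device kernel for an aten op via torch.library.
    # This mirrors the try/except probe in the tests above.
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")

    def relu_meta(self):
        # A meta kernel only propagates metadata (shape/dtype/stride);
        # no real memory is allocated and no data is computed.
        return torch.empty_like(self)

    # May raise if a Meta kernel for aten::relu is already registered.
    meta_lib.impl("relu", relu_meta)

    x = torch.rand(4, 3, device='meta')
    y = torch.relu(x)    # dispatches to relu_meta on the meta device
    assert y.shape == x.shape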