diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py
index 044a464be..61951e9a5 100644
--- a/tests/test_analyzer/test_fx/test_bias_addition.py
+++ b/tests/test_analyzer/test_fx/test_bias_addition.py
@@ -3,6 +3,8 @@ import torch
 from packaging import version
 from torch.utils.checkpoint import checkpoint
 
+from colossalai.testing.utils import parameterize
+
 try:
     from colossalai._analyzer.fx import symbolic_trace
 except:
@@ -56,9 +58,13 @@ class SiuModel(torch.nn.Module):
         self.linear = LinearModel(3, 3, bias)
         self.conv = ConvModel(3, 6, 3, bias)
 
-    def forward(self, x, select=0):
+    def forward(self, x, select=torch.Tensor([0])):
         x = self.linear(x)
-        x = checkpoint(self.conv, x, select)
+        if select:
+            x = checkpoint(self.conv, x, 0)
+        else:
+            x = checkpoint(self.conv, x, 1)
+
         return x
 
 
@@ -75,10 +81,10 @@ class AddmmModel(torch.nn.Module):
 
 
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("bias_addition_split", [True, False])
-@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
-@pytest.mark.parametrize("select", [0, 1])
+@parameterize("bias", [True, False])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
+@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
 def test_siu_model(bias, bias_addition_split, shape, select):
     model = SiuModel(bias=bias)
     x = torch.rand(shape)
@@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
                         concrete_args={'select': select},
                         trace_act_ckpt=True,
                         bias_addition_split=bias_addition_split)
-    assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
+    assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
     if bias and bias_addition_split:
         assert '+' in gm.code, 'bias addition should be split!'
     else:
         assert '+' not in gm.code, 'bias addition should not be split!'
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize("alpha", [1, 2])
-@pytest.mark.parametrize("beta", [1, 2])
-@pytest.mark.parametrize("bias_addition_split", [True, False])
-@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize("alpha", [1, 2])
+@parameterize("beta", [1, 2])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3), (5, 5)])
 def test_addmm_model(alpha, beta, bias_addition_split, shape):
     model = AddmmModel(alpha=alpha, beta=beta)
     x = torch.rand(shape)
@@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):
 
 
 if __name__ == '__main__':
-    test_siu_model(True, True, (3, 3, 3))
+    test_siu_model()
+    test_addmm_model()
diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py
index b19884a70..08f4ff2cb 100644
--- a/tests/test_analyzer/test_fx/test_shape_prop.py
+++ b/tests/test_analyzer/test_fx/test_shape_prop.py
@@ -1,16 +1,17 @@
 import pytest
-import timm.models as tmm
 import torch
 import torchvision.models as tm
-from .zoo import tm_models, tmm_models
+from packaging import version
+
+from colossalai.testing.utils import parameterize
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 
 try:
     from colossalai._analyzer._subclasses import MetaTensorMode
     from colossalai._analyzer.fx import symbolic_trace
     from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
     from colossalai._analyzer.fx.symbolic_profile import register_shape_impl
-
-
+
     @register_shape_impl(torch.nn.functional.linear)
     def linear_impl(*args, **kwargs):
         assert True
@@ -23,15 +24,15 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
     for node in gm.graph.nodes:
         assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
         if node.op in [
-                # 'call_module',  # can apply to params
-                # 'call_function',  # can apply to params
-                # 'call_method',  # can apply to params
+                'call_module',  # can apply to params
+                'call_function',  # can apply to params
+                'call_method',  # can apply to params
         ]:
-            assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'
+            assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.'
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tm_models)
 def test_torchvision_shape_prop(m):
     with MetaTensorMode():
         model = m()
@@ -44,8 +45,8 @@ def test_torchvision_shape_prop(m):
     _check_gm_validity(gm)
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tmm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tmm_models)
 def test_timm_shape_prop(m):
     with MetaTensorMode():
         model = m()
@@ -53,11 +54,12 @@ def test_timm_shape_prop(m):
     meta_args = {
         "x": data,
     }
+
     gm = symbolic_trace(model, meta_args=meta_args)
     shape_prop_pass(gm, data)
     _check_gm_validity(gm)
 
 
 if __name__ == "__main__":
-    test_torchvision_shape_prop(tm.resnet18)
-    test_timm_shape_prop(tmm.vgg11)
+    test_torchvision_shape_prop()
+    test_timm_shape_prop()
diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py
index 5f749e6f3..be781599f 100644
--- a/tests/test_analyzer/test_fx/test_symbolic_profile.py
+++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py
@@ -1,8 +1,10 @@
 import pytest
-import timm.models as tmm
 import torch
 import torchvision.models as tm
-from .zoo import tm_models, tmm_models
+from packaging import version
+
+from colossalai.testing.utils import parameterize
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 
 try:
     from colossalai._analyzer._subclasses import MetaTensorMode
@@ -16,8 +18,8 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
         assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tm_models)
 def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
     with MetaTensorMode():
         model = m()
@@ -30,8 +32,8 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
     _check_gm_validity(gm)
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tmm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tmm_models)
 def test_timm_profile(m, verbose=False, bias_addition_split=False):
     with MetaTensorMode():
         model = m()
@@ -45,5 +47,5 @@ def test_timm_profile(m, verbose=False, bias_addition_split=False):
 
 
 if __name__ == "__main__":
-    test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
-    test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)
+    test_torchvision_profile()
+    test_timm_profile()
diff --git a/tests/test_analyzer/test_fx/zoo.py b/tests/test_analyzer/test_fx/zoo.py
index 925078d0d..a96aa3949 100644
--- a/tests/test_analyzer/test_fx/zoo.py
+++ b/tests/test_analyzer/test_fx/zoo.py
@@ -33,18 +33,18 @@ tmm_models = [
     tmm.dm_nfnet_f0,
     tmm.eca_nfnet_l0,
     tmm.efficientformer_l1,
-    tmm.ese_vovnet19b_dw,
+    # tmm.ese_vovnet19b_dw,
     tmm.gmixer_12_224,
     tmm.gmlp_b16_224,
-    tmm.hardcorenas_a,
+    # tmm.hardcorenas_a,
     tmm.hrnet_w18_small,
     tmm.inception_v3,
     tmm.mixer_b16_224,
     tmm.nf_ecaresnet101,
     tmm.nf_regnet_b0,
     # tmm.pit_b_224,  # pretrained only
-    tmm.regnetv_040,
-    tmm.skresnet18,
+    # tmm.regnetv_040,
+    # tmm.skresnet18,
     # tmm.swin_base_patch4_window7_224,  # fx bad case
     # tmm.tnt_b_patch16_224,  # bad case
     tmm.vgg11,
diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py
index 551628103..752836141 100644
--- a/tests/test_analyzer/test_subclasses/test_flop_tensor.py
+++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py
@@ -1,9 +1,10 @@
 import pytest
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as tm
-from .zoo import tm_models, tmm_models
+from packaging import version
+
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 
 try:
     from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
@@ -11,7 +12,7 @@ except:
     pass
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize('m', tm_models + tmm_models)
 def test_flop_count_module(m):
     x = torch.rand(2, 3, 224, 224)
@@ -37,7 +38,7 @@ odd_cases = [
 ]
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize('func, args, kwargs', odd_cases)
 def test_flop_count_function(func, args, kwargs):
     rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
@@ -46,5 +47,5 @@ def test_flop_count_function(func, args, kwargs):
 
 
 if __name__ == '__main__':
-    test_flop_count_module(tm.resnet18, torch.rand(2, 3, 224, 224))
+    test_flop_count_module(tm.resnet18)
     test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py
index d8122b019..160d411f6 100644
--- a/tests/test_analyzer/test_subclasses/test_meta_mode.py
+++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py
@@ -1,12 +1,13 @@
 import pytest
 import torch
-import torch.distributed as dist
 import torchvision.models as tm
+from packaging import version
+
 try:
     from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
 except:
     pass
-from .zoo import tm_models, tmm_models
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 
 
 def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
@@ -28,7 +29,7 @@ def run_and_compare(model):
     compare_all(x.grad, meta_x.grad)
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize('m', tm_models + tmm_models)
 def test_meta_mode_shape(m):
     run_and_compare(m())
diff --git a/tests/test_analyzer/test_subclasses/zoo.py b/tests/test_analyzer/test_subclasses/zoo.py
deleted file mode 100644
index 925078d0d..000000000
--- a/tests/test_analyzer/test_subclasses/zoo.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import timm.models as tmm
-import torchvision.models as tm
-
-# input shape: (batch_size, 3, 224, 224)
-tm_models = [
-    tm.alexnet,
-    tm.convnext_base,
-    tm.densenet121,
-    # tm.efficientnet_v2_s,
-    # tm.googlenet,  # output bad case
-    # tm.inception_v3,  # bad case
-    tm.mobilenet_v2,
-    tm.mobilenet_v3_small,
-    tm.mnasnet0_5,
-    tm.resnet18,
-    tm.regnet_x_16gf,
-    tm.resnext50_32x4d,
-    tm.shufflenet_v2_x0_5,
-    tm.squeezenet1_0,
-    # tm.swin_s,  # fx bad case
-    tm.vgg11,
-    tm.vit_b_16,
-    tm.wide_resnet50_2,
-]
-
-tmm_models = [
-    tmm.beit_base_patch16_224,
-    tmm.beitv2_base_patch16_224,
-    tmm.cait_s24_224,
-    tmm.coat_lite_mini,
-    tmm.convit_base,
-    tmm.deit3_base_patch16_224,
-    tmm.dm_nfnet_f0,
-    tmm.eca_nfnet_l0,
-    tmm.efficientformer_l1,
-    tmm.ese_vovnet19b_dw,
-    tmm.gmixer_12_224,
-    tmm.gmlp_b16_224,
-    tmm.hardcorenas_a,
-    tmm.hrnet_w18_small,
-    tmm.inception_v3,
-    tmm.mixer_b16_224,
-    tmm.nf_ecaresnet101,
-    tmm.nf_regnet_b0,
-    # tmm.pit_b_224,  # pretrained only
-    tmm.regnetv_040,
-    tmm.skresnet18,
-    # tmm.swin_base_patch4_window7_224,  # fx bad case
-    # tmm.tnt_b_patch16_224,  # bad case
-    tmm.vgg11,
-    tmm.vit_base_patch16_18x2_224,
-    tmm.wide_resnet50_2,
-]
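
Note on the decorator swap above: unlike `@pytest.mark.parametrize`, colossalai's `@parameterize` iterates over the supplied values itself when the decorated function is called, which is why the `__main__` blocks can now invoke `test_siu_model()`, `test_addmm_model()`, etc. with no arguments. A minimal sketch of a decorator with that looping behavior (illustrative only, assuming `colossalai.testing.utils.parameterize` works roughly like this; it is not the actual implementation):

    import functools

    def parameterize(arg_name, values):
        # Each call to the decorated function runs it once per value;
        # stacked decorators therefore expand to the full cartesian product.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for value in values:
                    kwargs[arg_name] = value
                    fn(*args, **kwargs)
            return wrapper
        return decorator

One trade-off of this style: pytest collects each decorated function as a single test case, so the per-combination reporting that `@pytest.mark.parametrize` provides is lost.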