From a674c6334846aa4af71703961d68907e8d0611b2 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 15 Mar 2023 10:42:07 +0800 Subject: [PATCH] [test] added torchvision models to test model zoo (#3132) * [test] added torchvision models to test model zoo * polish code * polish code * polish code * polish code * polish code * polish code --- tests/kit/model_zoo/__init__.py | 2 +- tests/kit/model_zoo/registry.py | 5 + tests/kit/model_zoo/torchvision/__init__.py | 1 + .../kit/model_zoo/torchvision/torchvision.py | 131 ++++++++++++++++++ .../test_torchvision_model.py | 49 ++++--- 5 files changed, 162 insertions(+), 26 deletions(-) create mode 100644 tests/kit/model_zoo/torchvision/__init__.py create mode 100644 tests/kit/model_zoo/torchvision/torchvision.py diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 6d77fb850..abe18ebfa 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,4 @@ -from . import diffusers, timm +from . import diffusers, timm, torchvision from .registry import model_zoo __all__ = ['model_zoo'] diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 4e7dcb30f..7470327a6 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -9,8 +9,13 @@ __all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo'] class ModelAttribute: """ Attributes of a model. + + Args: + has_control_flow (bool): Whether the model contains branching in its forward method. + has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models. """ has_control_flow: bool = False + has_stochastic_depth_prob: bool = False class ModelZooRegistry(dict): diff --git a/tests/kit/model_zoo/torchvision/__init__.py b/tests/kit/model_zoo/torchvision/__init__.py new file mode 100644 index 000000000..55d58f97b --- /dev/null +++ b/tests/kit/model_zoo/torchvision/__init__.py @@ -0,0 +1 @@ +from .torchvision import * diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py new file mode 100644 index 000000000..62bda93d5 --- /dev/null +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -0,0 +1,131 @@ +from collections import namedtuple + +import torch +import torchvision +import torchvision.models as tm +from packaging import version + +from ..registry import ModelAttribute, model_zoo + +data_gen_fn = lambda: dict(x=torch.rand(4, 3, 224, 224)) +output_transform_fn = lambda x: dict(output=x) + +# special data gen fn +inception_v3_data_gen_fn = lambda: dict(x=torch.rand(4, 3, 299, 299)) + + +# special model fn +def swin_s(): + from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer + + # adapted from torchvision.models.swin_transformer.swin_small + weights = None + weights = Swin_T_Weights.verify(weights) + progress = True + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic + weights=weights, + progress=progress, + ) + + +# special output transform fn +google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs + ) else dict(output=x) +swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val + for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs + ) else dict(output=x) + +model_zoo.register(name='torchvision_alexnet', + model_fn=tm.alexnet, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_densenet121', + model_fn=tm.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_efficientnet_b0', + model_fn=tm.efficientnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) +model_zoo.register(name='torchvision_googlenet', + model_fn=tm.googlenet, + data_gen_fn=data_gen_fn, + output_transform_fn=google_net_output_transform_fn) +model_zoo.register(name='torchvision_inception_v3', + model_fn=tm.inception_v3, + data_gen_fn=inception_v3_data_gen_fn, + output_transform_fn=inception_v3_output_transform_fn) +model_zoo.register(name='torchvision_mobilenet_v2', + model_fn=tm.mobilenet_v2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_mobilenet_v3_small', + model_fn=tm.mobilenet_v3_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_mnasnet0_5', + model_fn=tm.mnasnet0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_resnet18', + model_fn=tm.resnet18, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_regnet_x_16gf', + model_fn=tm.regnet_x_16gf, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_resnext50_32x4d', + model_fn=tm.resnext50_32x4d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_shufflenet_v2_x0_5', + model_fn=tm.shufflenet_v2_x0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_squeezenet1_0', + model_fn=tm.squeezenet1_0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='torchvision_vgg11', + model_fn=tm.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_wide_resnet50_2', + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + model_zoo.register(name='torchvision_vit_b_16', + model_fn=tm.vit_b_16, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + model_zoo.register(name='torchvision_convnext_base', + model_fn=tm.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) + +if version.parse(torchvision.__version__) >= version.parse('0.13.0'): + model_zoo.register( + name='torchvision_swin_s', + model_fn=swin_s, + data_gen_fn=data_gen_fn, + output_transform_fn=swin_s_output_output_transform_fn, + ) + model_zoo.register(name='torchvision_efficientnet_v2_s', + model_fn=tm.efficientnet_v2_s, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 2a6c6ae16..455638818 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -1,44 +1,43 @@ import torch -import torchvision -import torchvision.models as tm -from packaging import version from colossalai.fx import symbolic_trace +from tests.kit.model_zoo import model_zoo def test_torchvision_models(): - MODEL_LIST = [ - tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, - tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0 - ] - - RANDOMIZED_MODELS = [tm.efficientnet_b0] - - if version.parse(torchvision.__version__) >= version.parse('0.12.0'): - MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) - RANDOMIZED_MODELS.append(tm.convnext_small) - torch.backends.cudnn.deterministic = True + tv_sub_registry = model_zoo.get_sub_registry('torchvision') - data = torch.rand(2, 3, 224, 224) + for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items(): + data = data_gen_fn() - for model_cls in MODEL_LIST: - if model_cls in RANDOMIZED_MODELS: - # remove the impact of randomicity - model = model_cls(stochastic_depth_prob=0) + if model_attribute is not None and model_attribute.has_stochastic_depth_prob: + model = model_fn(stochastic_depth_prob=0) else: - model = model_cls() + model = model_fn() gm = symbolic_trace(model) model.eval() gm.eval() - with torch.no_grad(): - fx_out = gm(data) - non_fx_out = model(data) - assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + try: + with torch.no_grad(): + fx_out = gm(**data) + non_fx_out = model(**data) + transformed_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_out) == len(transformed_non_fx_out) + + for key in transformed_out.keys(): + fx_val = transformed_out[key] + non_fx_val = transformed_non_fx_out[key] + assert torch.allclose( + fx_val, + non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}' + except Exception as e: + print(name, e) if __name__ == '__main__':