[test] added torchvision models to test model zoo (#3132)

* [test] added torchvision models to test model zoo

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
pull/3135/head
Frank Lee 2023-03-15 10:42:07 +08:00 committed by GitHub
parent 1216d1e7bd
commit a674c63348
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 26 deletions

View File

@ -1,4 +1,4 @@
from . import diffusers, timm
from . import diffusers, timm, torchvision
from .registry import model_zoo
__all__ = ['model_zoo']

View File

@ -9,8 +9,13 @@ __all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
class ModelAttribute:
"""
Attributes of a model.
Args:
has_control_flow (bool): Whether the model contains branching in its forward method.
has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models.
"""
has_control_flow: bool = False
has_stochastic_depth_prob: bool = False
class ModelZooRegistry(dict):

View File

@ -0,0 +1 @@
from .torchvision import *

View File

@ -0,0 +1,131 @@
from collections import namedtuple
import torch
import torchvision
import torchvision.models as tm
from packaging import version
from ..registry import ModelAttribute, model_zoo
data_gen_fn = lambda: dict(x=torch.rand(4, 3, 224, 224))
output_transform_fn = lambda x: dict(output=x)
# special data gen fn
inception_v3_data_gen_fn = lambda: dict(x=torch.rand(4, 3, 299, 299))
# special model fn
def swin_s():
from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer
# adapted from torchvision.models.swin_transformer.swin_small
weights = None
weights = Swin_T_Weights.verify(weights)
progress = True
return _swin_transformer(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[7, 7],
stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic
weights=weights,
progress=progress,
)
# special output transform fn
google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs
) else dict(output=x)
swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs
) else dict(output=x)
model_zoo.register(name='torchvision_alexnet',
model_fn=tm.alexnet,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_densenet121',
model_fn=tm.densenet121,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_efficientnet_b0',
model_fn=tm.efficientnet_b0,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
model_zoo.register(name='torchvision_googlenet',
model_fn=tm.googlenet,
data_gen_fn=data_gen_fn,
output_transform_fn=google_net_output_transform_fn)
model_zoo.register(name='torchvision_inception_v3',
model_fn=tm.inception_v3,
data_gen_fn=inception_v3_data_gen_fn,
output_transform_fn=inception_v3_output_transform_fn)
model_zoo.register(name='torchvision_mobilenet_v2',
model_fn=tm.mobilenet_v2,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_mobilenet_v3_small',
model_fn=tm.mobilenet_v3_small,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_mnasnet0_5',
model_fn=tm.mnasnet0_5,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_resnet18',
model_fn=tm.resnet18,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_regnet_x_16gf',
model_fn=tm.regnet_x_16gf,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_resnext50_32x4d',
model_fn=tm.resnext50_32x4d,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_shufflenet_v2_x0_5',
model_fn=tm.shufflenet_v2_x0_5,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_squeezenet1_0',
model_fn=tm.squeezenet1_0,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_vgg11',
model_fn=tm.vgg11,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_wide_resnet50_2',
model_fn=tm.wide_resnet50_2,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
model_zoo.register(name='torchvision_vit_b_16',
model_fn=tm.vit_b_16,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_convnext_base',
model_fn=tm.convnext_base,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
if version.parse(torchvision.__version__) >= version.parse('0.13.0'):
model_zoo.register(
name='torchvision_swin_s',
model_fn=swin_s,
data_gen_fn=data_gen_fn,
output_transform_fn=swin_s_output_output_transform_fn,
)
model_zoo.register(name='torchvision_efficientnet_v2_s',
model_fn=tm.efficientnet_v2_s,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_stochastic_depth_prob=True))

View File

@ -1,44 +1,43 @@
import torch
import torchvision
import torchvision.models as tm
from packaging import version
from colossalai.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
]
RANDOMIZED_MODELS = [tm.efficientnet_b0]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
RANDOMIZED_MODELS.append(tm.convnext_small)
torch.backends.cudnn.deterministic = True
tv_sub_registry = model_zoo.get_sub_registry('torchvision')
data = torch.rand(2, 3, 224, 224)
for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
data = data_gen_fn()
for model_cls in MODEL_LIST:
if model_cls in RANDOMIZED_MODELS:
# remove the impact of randomicity
model = model_cls(stochastic_depth_prob=0)
if model_attribute is not None and model_attribute.has_stochastic_depth_prob:
model = model_fn(stochastic_depth_prob=0)
else:
model = model_cls()
model = model_fn()
gm = symbolic_trace(model)
model.eval()
gm.eval()
with torch.no_grad():
fx_out = gm(data)
non_fx_out = model(data)
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
try:
with torch.no_grad():
fx_out = gm(**data)
non_fx_out = model(**data)
transformed_out = output_transform_fn(fx_out)
transformed_non_fx_out = output_transform_fn(non_fx_out)
assert len(transformed_out) == len(transformed_non_fx_out)
for key in transformed_out.keys():
fx_val = transformed_out[key]
non_fx_val = transformed_non_fx_out[key]
assert torch.allclose(
fx_val,
non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}'
except Exception as e:
print(name, e)
if __name__ == '__main__':