[test] added timm models to test model zoo (#3129)

* [test] added timm models to test model zoo

* polish code

* polish code

* polish code

* polish code

* polish code
pull/3136/head
Frank Lee 2023-03-14 14:29:18 +08:00 committed by GitHub
parent 23cd5e2ccf
commit 86ac782d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 258 additions and 43 deletions

0
tests/kit/__init__.py Normal file
View File

View File

@ -0,0 +1,4 @@
from . import timm
from .registry import model_zoo
__all__ = ['model_zoo']

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python
from dataclasses import dataclass
from typing import Callable
__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
@dataclass
class ModelAttribute:
"""
Attributes of a model.
"""
has_control_flow: bool = False
class ModelZooRegistry(dict):
"""
A registry to map model names to model and data generation functions.
"""
def register(self,
name: str,
model_fn: Callable,
data_gen_fn: Callable,
output_transform_fn: Callable,
model_attribute: ModelAttribute = None):
"""
Register a model and data generation function.
Examples:
>>> # Register
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnresnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)
Args:
name (str): Name of the model.
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
output_transform_fn (callable): A function that transforms the output of the model into Dict.
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
def get_sub_registry(self, keyword: str):
"""
Get a sub registry with models that contain the keyword.
Args:
keyword (str): Keyword to filter models.
"""
new_dict = dict()
for k, v in self.items():
if keyword in k:
new_dict[k] = v
return new_dict
model_zoo = ModelZooRegistry()

View File

@ -0,0 +1 @@
from .timm import *

View File

@ -0,0 +1,159 @@
import timm.models as tm
import torch
from ..registry import ModelAttribute, model_zoo
## ==============
# Register models without control flow
## ==============
data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224))
output_transform_fn = lambda x: dict(output=x)
model_zoo.register(name='timm_resnet',
model_fn=tm.resnest.resnest50d,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_beit',
model_fn=tm.beit.beit_base_patch16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_cait',
model_fn=tm.cait.cait_s24_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_convmixer',
model_fn=tm.convmixer.convmixer_768_32,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_efficientnetv2',
model_fn=tm.efficientnet.efficientnetv2_m,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_resmlp',
model_fn=tm.resmlp_12_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_vision_transformer',
model_fn=tm.vision_transformer.vit_base_patch16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_deit',
model_fn=tm.deit_base_distilled_patch16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_beitv2',
model_fn=tm.beitv2_base_patch16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_coat',
model_fn=tm.coat.coat_lite_mini,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_deit3',
model_fn=tm.deit3_base_patch16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_eca_nfnet',
model_fn=tm.eca_nfnet_l0,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_efficientformer',
model_fn=tm.efficientformer_l1,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_ese_vovnet19b_dw',
model_fn=tm.ese_vovnet19b_dw,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_gmixer_12_224',
model_fn=tm.gmixer_12_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_gmlp_b16_224',
model_fn=tm.gmlp_b16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_hardcorenas_a',
model_fn=tm.hardcorenas_a,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_hrnet_w18_small',
model_fn=tm.hrnet_w18_small,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_inception_v3',
model_fn=tm.inception_v3,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_mixer_b16_224',
model_fn=tm.mixer_b16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_nf_ecaresnet101',
model_fn=tm.nf_ecaresnet101,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_nf_regnet_b0',
model_fn=tm.nf_regnet_b0,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_regnetv_040',
model_fn=tm.regnetv_040,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_skresnet18',
model_fn=tm.skresnet18,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_tnt_b_patch16_224',
model_fn=tm.tnt_b_patch16_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_wide_resnet50_2',
model_fn=tm.wide_resnet50_2,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_convit',
model_fn=tm.convit_base,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
model_zoo.register(name='timm_dm_nfnet',
model_fn=tm.dm_nfnet_f0,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn)
# ==============
# Register models with control flow
# ==============
model_zoo.register(name='timm_convnext',
model_fn=tm.convnext.convnext_base,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='timm_vgg',
model_fn=tm.vgg.vgg11,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='timm_dpn',
model_fn=tm.dpn.dpn68,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='timm_densenet',
model_fn=tm.densenet.densenet121,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='timm_rexnet',
model_fn=tm.rexnet.rexnet_100,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='timm_swin_transformer',
model_fn=tm.swin_transformer.swin_base_patch4_window7_224,
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))

View File

@ -3,9 +3,10 @@ import timm.models as tm
import torch
from colossalai.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
def trace_and_compare(model_cls, data, meta_args=None):
def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
# trace
model = model_cls()
@ -14,60 +15,47 @@ def trace_and_compare(model_cls, data, meta_args=None):
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model.eval()
# TODO: support the following models
# 1. ConViT
# 2. NormFreeNet
# as they are not supported, let's skip them
if model.__class__.__name__ in ['ConViT', 'NormFreeNet']:
return
gm = symbolic_trace(model, meta_args=meta_args)
# run forward
with torch.no_grad():
fx_out = gm(data)
non_fx_out = model(data)
fx_out = gm(**data)
non_fx_out = model(**data)
# compare output
if isinstance(fx_out, tuple):
# some models produce tuple as output
for v1, v2 in zip(fx_out, non_fx_out):
assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}'
else:
assert torch.allclose(
fx_out, non_fx_out,
atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
transformed_fx_out = output_transform_fn(fx_out)
transformed_non_fx_out = output_transform_fn(non_fx_out)
assert len(transformed_fx_out) == len(transformed_non_fx_out)
for key in transformed_fx_out.keys():
fx_output_val = transformed_fx_out[key]
non_fx_output_val = transformed_non_fx_out[key]
assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
def test_timm_models_without_control_flow():
def test_timm_models():
torch.backends.cudnn.deterministic = True
MODEL_LIST = [
tm.resnest.resnest50d,
tm.beit.beit_base_patch16_224,
tm.cait.cait_s24_224,
tm.convmixer.convmixer_768_32,
tm.efficientnet.efficientnetv2_m,
tm.resmlp_12_224,
tm.vision_transformer.vit_base_patch16_224,
tm.deit_base_distilled_patch16_224,
]
sub_model_zoo = model_zoo.get_sub_registry('timm')
data = torch.rand(2, 3, 224, 224)
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}
else:
meta_args = None
for model_cls in MODEL_LIST:
trace_and_compare(model_cls, data)
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True
MODEL_LIST_WITH_CONTROL_FLOW = [
tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
tm.swin_transformer.swin_base_patch4_window7_224
]
data = torch.rand(2, 3, 224, 224)
meta_args = {'x': data.to('meta')}
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
trace_and_compare(model_cls, data, meta_args)
trace_and_compare(model_fn, data, output_transform_fn, meta_args)
if __name__ == '__main__':
test_timm_models_with_control_flow()
test_timm_models_without_control_flow()
test_timm_models()