mirror of https://github.com/hpcaitech/ColossalAI
[test] added timm models to test model zoo (#3129)
* [test] added timm models to test model zoo * polish code * polish code * polish code * polish code * polish codepull/3136/head
parent
23cd5e2ccf
commit
86ac782d7c
@ -0,0 +1,4 @@
|
||||
from . import timm
|
||||
from .registry import model_zoo
|
||||
|
||||
__all__ = ['model_zoo']
|
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelAttribute:
|
||||
"""
|
||||
Attributes of a model.
|
||||
"""
|
||||
has_control_flow: bool = False
|
||||
|
||||
|
||||
class ModelZooRegistry(dict):
|
||||
"""
|
||||
A registry to map model names to model and data generation functions.
|
||||
"""
|
||||
|
||||
def register(self,
|
||||
name: str,
|
||||
model_fn: Callable,
|
||||
data_gen_fn: Callable,
|
||||
output_transform_fn: Callable,
|
||||
model_attribute: ModelAttribute = None):
|
||||
"""
|
||||
Register a model and data generation function.
|
||||
|
||||
Examples:
|
||||
>>> # Register
|
||||
>>> model_zoo = ModelZooRegistry()
|
||||
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
|
||||
>>> # Run the model
|
||||
>>> data = resnresnet18_data_gen() # do not input any argument
|
||||
>>> model = resnet18() # do not input any argument
|
||||
>>> out = model(**data)
|
||||
|
||||
Args:
|
||||
name (str): Name of the model.
|
||||
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
|
||||
output_transform_fn (callable): A function that transforms the output of the model into Dict.
|
||||
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
|
||||
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
|
||||
"""
|
||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
|
||||
|
||||
def get_sub_registry(self, keyword: str):
|
||||
"""
|
||||
Get a sub registry with models that contain the keyword.
|
||||
|
||||
Args:
|
||||
keyword (str): Keyword to filter models.
|
||||
"""
|
||||
new_dict = dict()
|
||||
|
||||
for k, v in self.items():
|
||||
if keyword in k:
|
||||
new_dict[k] = v
|
||||
return new_dict
|
||||
|
||||
|
||||
model_zoo = ModelZooRegistry()
|
@ -0,0 +1 @@
|
||||
from .timm import *
|
@ -0,0 +1,159 @@
|
||||
import timm.models as tm
|
||||
import torch
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
## ==============
|
||||
# Register models without control flow
|
||||
## ==============
|
||||
data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224))
|
||||
output_transform_fn = lambda x: dict(output=x)
|
||||
|
||||
model_zoo.register(name='timm_resnet',
|
||||
model_fn=tm.resnest.resnest50d,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_beit',
|
||||
model_fn=tm.beit.beit_base_patch16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_cait',
|
||||
model_fn=tm.cait.cait_s24_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_convmixer',
|
||||
model_fn=tm.convmixer.convmixer_768_32,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_efficientnetv2',
|
||||
model_fn=tm.efficientnet.efficientnetv2_m,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_resmlp',
|
||||
model_fn=tm.resmlp_12_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_vision_transformer',
|
||||
model_fn=tm.vision_transformer.vit_base_patch16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_deit',
|
||||
model_fn=tm.deit_base_distilled_patch16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_beitv2',
|
||||
model_fn=tm.beitv2_base_patch16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_coat',
|
||||
model_fn=tm.coat.coat_lite_mini,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
|
||||
model_zoo.register(name='timm_deit3',
|
||||
model_fn=tm.deit3_base_patch16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
|
||||
model_zoo.register(name='timm_eca_nfnet',
|
||||
model_fn=tm.eca_nfnet_l0,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_efficientformer',
|
||||
model_fn=tm.efficientformer_l1,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_ese_vovnet19b_dw',
|
||||
model_fn=tm.ese_vovnet19b_dw,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_gmixer_12_224',
|
||||
model_fn=tm.gmixer_12_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_gmlp_b16_224',
|
||||
model_fn=tm.gmlp_b16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_hardcorenas_a',
|
||||
model_fn=tm.hardcorenas_a,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_hrnet_w18_small',
|
||||
model_fn=tm.hrnet_w18_small,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_inception_v3',
|
||||
model_fn=tm.inception_v3,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_mixer_b16_224',
|
||||
model_fn=tm.mixer_b16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_nf_ecaresnet101',
|
||||
model_fn=tm.nf_ecaresnet101,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_nf_regnet_b0',
|
||||
model_fn=tm.nf_regnet_b0,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_regnetv_040',
|
||||
model_fn=tm.regnetv_040,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_skresnet18',
|
||||
model_fn=tm.skresnet18,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_tnt_b_patch16_224',
|
||||
model_fn=tm.tnt_b_patch16_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_wide_resnet50_2',
|
||||
model_fn=tm.wide_resnet50_2,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_convit',
|
||||
model_fn=tm.convit_base,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(name='timm_dm_nfnet',
|
||||
model_fn=tm.dm_nfnet_f0,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
|
||||
# ==============
|
||||
# Register models with control flow
|
||||
# ==============
|
||||
model_zoo.register(name='timm_convnext',
|
||||
model_fn=tm.convnext.convnext_base,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='timm_vgg',
|
||||
model_fn=tm.vgg.vgg11,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='timm_dpn',
|
||||
model_fn=tm.dpn.dpn68,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='timm_densenet',
|
||||
model_fn=tm.densenet.densenet121,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='timm_rexnet',
|
||||
model_fn=tm.rexnet.rexnet_100,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='timm_swin_transformer',
|
||||
model_fn=tm.swin_transformer.swin_base_patch4_window7_224,
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
Loading…
Reference in new issue