mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* [test] added timm models to test model zoo * polish code * polish code * polish code * polish code * polish codepull/3136/head
Frank Lee
2 years ago
committed by
GitHub
6 changed files with 256 additions and 41 deletions
@ -0,0 +1,4 @@
|
||||
from . import timm |
||||
from .registry import model_zoo |
||||
|
||||
__all__ = ['model_zoo'] |
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python |
||||
from dataclasses import dataclass |
||||
from typing import Callable |
||||
|
||||
__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo'] |
||||
|
||||
|
||||
@dataclass |
||||
class ModelAttribute: |
||||
""" |
||||
Attributes of a model. |
||||
""" |
||||
has_control_flow: bool = False |
||||
|
||||
|
||||
class ModelZooRegistry(dict): |
||||
""" |
||||
A registry to map model names to model and data generation functions. |
||||
""" |
||||
|
||||
def register(self, |
||||
name: str, |
||||
model_fn: Callable, |
||||
data_gen_fn: Callable, |
||||
output_transform_fn: Callable, |
||||
model_attribute: ModelAttribute = None): |
||||
""" |
||||
Register a model and data generation function. |
||||
|
||||
Examples: |
||||
>>> # Register |
||||
>>> model_zoo = ModelZooRegistry() |
||||
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) |
||||
>>> # Run the model |
||||
>>> data = resnresnet18_data_gen() # do not input any argument |
||||
>>> model = resnet18() # do not input any argument |
||||
>>> out = model(**data) |
||||
|
||||
Args: |
||||
name (str): Name of the model. |
||||
model_fn (callable): A function that returns a model. **It must not contain any arguments.** |
||||
output_transform_fn (callable): A function that transforms the output of the model into Dict. |
||||
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** |
||||
model_attribute (ModelAttribute): Attributes of the model. Defaults to None. |
||||
""" |
||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute) |
||||
|
||||
def get_sub_registry(self, keyword: str): |
||||
""" |
||||
Get a sub registry with models that contain the keyword. |
||||
|
||||
Args: |
||||
keyword (str): Keyword to filter models. |
||||
""" |
||||
new_dict = dict() |
||||
|
||||
for k, v in self.items(): |
||||
if keyword in k: |
||||
new_dict[k] = v |
||||
return new_dict |
||||
|
||||
|
||||
model_zoo = ModelZooRegistry() |
@ -0,0 +1,159 @@
|
||||
import timm.models as tm |
||||
import torch |
||||
|
||||
from ..registry import ModelAttribute, model_zoo |
||||
|
||||
## ============== |
||||
# Register models without control flow |
||||
## ============== |
||||
data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224)) |
||||
output_transform_fn = lambda x: dict(output=x) |
||||
|
||||
model_zoo.register(name='timm_resnet', |
||||
model_fn=tm.resnest.resnest50d, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_beit', |
||||
model_fn=tm.beit.beit_base_patch16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_cait', |
||||
model_fn=tm.cait.cait_s24_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_convmixer', |
||||
model_fn=tm.convmixer.convmixer_768_32, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_efficientnetv2', |
||||
model_fn=tm.efficientnet.efficientnetv2_m, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_resmlp', |
||||
model_fn=tm.resmlp_12_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_vision_transformer', |
||||
model_fn=tm.vision_transformer.vit_base_patch16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_deit', |
||||
model_fn=tm.deit_base_distilled_patch16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_beitv2', |
||||
model_fn=tm.beitv2_base_patch16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_coat', |
||||
model_fn=tm.coat.coat_lite_mini, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
|
||||
model_zoo.register(name='timm_deit3', |
||||
model_fn=tm.deit3_base_patch16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
|
||||
model_zoo.register(name='timm_eca_nfnet', |
||||
model_fn=tm.eca_nfnet_l0, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_efficientformer', |
||||
model_fn=tm.efficientformer_l1, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_ese_vovnet19b_dw', |
||||
model_fn=tm.ese_vovnet19b_dw, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_gmixer_12_224', |
||||
model_fn=tm.gmixer_12_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_gmlp_b16_224', |
||||
model_fn=tm.gmlp_b16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_hardcorenas_a', |
||||
model_fn=tm.hardcorenas_a, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_hrnet_w18_small', |
||||
model_fn=tm.hrnet_w18_small, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_inception_v3', |
||||
model_fn=tm.inception_v3, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_mixer_b16_224', |
||||
model_fn=tm.mixer_b16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_nf_ecaresnet101', |
||||
model_fn=tm.nf_ecaresnet101, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_nf_regnet_b0', |
||||
model_fn=tm.nf_regnet_b0, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_regnetv_040', |
||||
model_fn=tm.regnetv_040, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_skresnet18', |
||||
model_fn=tm.skresnet18, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_tnt_b_patch16_224', |
||||
model_fn=tm.tnt_b_patch16_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_wide_resnet50_2', |
||||
model_fn=tm.wide_resnet50_2, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_convit', |
||||
model_fn=tm.convit_base, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
model_zoo.register(name='timm_dm_nfnet', |
||||
model_fn=tm.dm_nfnet_f0, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn) |
||||
|
||||
# ============== |
||||
# Register models with control flow |
||||
# ============== |
||||
model_zoo.register(name='timm_convnext', |
||||
model_fn=tm.convnext.convnext_base, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True)) |
||||
model_zoo.register(name='timm_vgg', |
||||
model_fn=tm.vgg.vgg11, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True)) |
||||
model_zoo.register(name='timm_dpn', |
||||
model_fn=tm.dpn.dpn68, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True)) |
||||
model_zoo.register(name='timm_densenet', |
||||
model_fn=tm.densenet.densenet121, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True)) |
||||
model_zoo.register(name='timm_rexnet', |
||||
model_fn=tm.rexnet.rexnet_100, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True)) |
||||
model_zoo.register(name='timm_swin_transformer', |
||||
model_fn=tm.swin_transformer.swin_base_patch4_window7_224, |
||||
data_gen_fn=data_gen_fn, |
||||
output_transform_fn=output_transform_fn, |
||||
model_attribute=ModelAttribute(has_control_flow=True)) |
Loading…
Reference in new issue