You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/kit/model_zoo/torchvision/torchvision.py

132 lines
5.8 KiB

from collections import namedtuple
import torch
import torchvision
import torchvision.models as tm
from packaging import version
from ..registry import ModelAttribute, model_zoo
def data_gen_fn():
    """Return the default model input: a random tensor of shape (4, 3, 224, 224) under key ``x``."""
    return {'x': torch.rand(4, 3, 224, 224)}


def output_transform_fn(x):
    """Wrap a model's raw output, unchanged, under the single key ``output``."""
    return {'output': x}


# special data gen fn
def inception_v3_data_gen_fn():
    """Return the Inception v3 input: a random (4, 3, 299, 299) tensor under key ``x``.

    Inception v3 is fed 299x299 inputs instead of the 224x224 used by ``data_gen_fn``.
    """
    return {'x': torch.rand(4, 3, 299, 299)}
# special model fn
def swin_s():
    """Build a randomly initialised Swin Transformer for the model zoo.

    NOTE: despite the function/registry name, the hyper-parameters below
    (``embed_dim=96``, ``depths=[2, 2, 6, 2]``, ``num_heads=[3, 6, 12, 24]``)
    are torchvision's ``swin_t`` (Swin-Tiny) configuration; torchvision's
    ``swin_s`` uses ``depths=[2, 2, 18, 2]``. The name is kept unchanged
    because the model is registered as ``torchvision_swin_s``.

    Stochastic depth is disabled so repeated forward passes are deterministic.

    Returns:
        The model built by torchvision's private ``_swin_transformer`` helper,
        with no pretrained weights loaded.
    """
    from torchvision.models.swin_transformer import _swin_transformer

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        # torchvision's swin_t uses 0.2; 0 removes the randomness of drop-path.
        stochastic_depth_prob=0,
        # Random initialisation: no weights enum to verify, nothing is downloaded.
        weights=None,
        progress=True,
    )
# special output transform fn
def google_net_output_transform_fn(x):
    """Reduce GoogLeNet output to a single ``output`` entry.

    A ``GoogLeNetOutputs`` namedtuple (main and auxiliary logits) is summed
    element-wise; any other value is passed through unchanged.
    """
    if not isinstance(x, torchvision.models.GoogLeNetOutputs):
        return {'output': x}
    return {'output': sum(x)}


def swin_s_output_output_transform_fn(x):
    """Spread a tuple output into keys ``output0``, ``output1``, ...; wrap anything else as ``output``."""
    if not isinstance(x, tuple):
        return {'output': x}
    return {f'output{idx}': val for idx, val in enumerate(x)}


def inception_v3_output_transform_fn(x):
    """Reduce Inception v3 output to a single ``output`` entry.

    An ``InceptionOutputs`` namedtuple (main and auxiliary logits) is summed
    element-wise; any other value is passed through unchanged.
    """
    if not isinstance(x, torchvision.models.InceptionOutputs):
        return {'output': x}
    return {'output': sum(x)}
# Models registered unconditionally, in registration order. Each row is
# (registry name, model constructor, data generator, output transform, attribute).
# Rows whose attribute is None are registered without a ``model_attribute`` kwarg.
for _name, _model_fn, _data_fn, _transform_fn, _attribute in (
    ('torchvision_alexnet', tm.alexnet, data_gen_fn, output_transform_fn, None),
    ('torchvision_densenet121', tm.densenet121, data_gen_fn, output_transform_fn, None),
    ('torchvision_efficientnet_b0', tm.efficientnet_b0, data_gen_fn, output_transform_fn,
     ModelAttribute(has_stochastic_depth_prob=True)),
    ('torchvision_googlenet', tm.googlenet, data_gen_fn, google_net_output_transform_fn, None),
    ('torchvision_inception_v3', tm.inception_v3, inception_v3_data_gen_fn, inception_v3_output_transform_fn,
     None),
    ('torchvision_mobilenet_v2', tm.mobilenet_v2, data_gen_fn, output_transform_fn, None),
    ('torchvision_mobilenet_v3_small', tm.mobilenet_v3_small, data_gen_fn, output_transform_fn, None),
    ('torchvision_mnasnet0_5', tm.mnasnet0_5, data_gen_fn, output_transform_fn, None),
    ('torchvision_resnet18', tm.resnet18, data_gen_fn, output_transform_fn, None),
    ('torchvision_regnet_x_16gf', tm.regnet_x_16gf, data_gen_fn, output_transform_fn, None),
    ('torchvision_resnext50_32x4d', tm.resnext50_32x4d, data_gen_fn, output_transform_fn, None),
    ('torchvision_shufflenet_v2_x0_5', tm.shufflenet_v2_x0_5, data_gen_fn, output_transform_fn, None),
    ('torchvision_squeezenet1_0', tm.squeezenet1_0, data_gen_fn, output_transform_fn, None),
    ('torchvision_vgg11', tm.vgg11, data_gen_fn, output_transform_fn, None),
    ('torchvision_wide_resnet50_2', tm.wide_resnet50_2, data_gen_fn, output_transform_fn, None),
):
    _extra = {} if _attribute is None else {'model_attribute': _attribute}
    model_zoo.register(name=_name,
                       model_fn=_model_fn,
                       data_gen_fn=_data_fn,
                       output_transform_fn=_transform_fn,
                       **_extra)
# Drop the loop temporaries so they do not linger in the module namespace.
del _name, _model_fn, _data_fn, _transform_fn, _attribute, _extra
# ViT-B/16 and ConvNeXt-Base are only registered on torchvision >= 0.12.
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
    for _kwargs in (
        dict(name='torchvision_vit_b_16',
             model_fn=tm.vit_b_16,
             data_gen_fn=data_gen_fn,
             output_transform_fn=output_transform_fn),
        dict(name='torchvision_convnext_base',
             model_fn=tm.convnext_base,
             data_gen_fn=data_gen_fn,
             output_transform_fn=output_transform_fn,
             model_attribute=ModelAttribute(has_stochastic_depth_prob=True)),
    ):
        model_zoo.register(**_kwargs)
    # Drop the loop temporary so it does not linger in the module namespace.
    del _kwargs
# The Swin model (built by the local ``swin_s`` helper) and EfficientNetV2-S
# are only registered on torchvision >= 0.13.
if version.parse(torchvision.__version__) >= version.parse('0.13.0'):
    for _kwargs in (
        dict(name='torchvision_swin_s',
             model_fn=swin_s,
             data_gen_fn=data_gen_fn,
             output_transform_fn=swin_s_output_output_transform_fn),
        dict(name='torchvision_efficientnet_v2_s',
             model_fn=tm.efficientnet_v2_s,
             data_gen_fn=data_gen_fn,
             output_transform_fn=output_transform_fn,
             model_attribute=ModelAttribute(has_stochastic_depth_prob=True)),
    ):
        model_zoo.register(**_kwargs)
    # Drop the loop temporary so it does not linger in the module namespace.
    del _kwargs