mirror of https://github.com/hpcaitech/ColossalAI
[test] added torchvision models to test model zoo (#3132)
* [test] added torchvision models to test model zoo
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code

branch: pull/3135/head
parent: 1216d1e7bd
commit: a674c63348
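Before the hunks, a minimal sketch (not part of the diff) of how the newly registered torchvision entries are meant to be consumed. The `get_sub_registry` call and the (model_fn, data_gen_fn, output_transform_fn, model_attribute) tuple layout are taken from the updated test at the end of this diff; everything else is an assumption.

import torch

from tests.kit.model_zoo import model_zoo

# Each entry added in this PR stores (model_fn, data_gen_fn, output_transform_fn, model_attribute),
# as unpacked by the updated tracer test further down.
tv_sub_registry = model_zoo.get_sub_registry('torchvision')
for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
    model = model_fn().eval()                        # build the torchvision model
    data = data_gen_fn()                             # e.g. dict(x=torch.rand(4, 3, 224, 224))
    with torch.no_grad():
        outputs = output_transform_fn(model(**data))    # always normalized to a dict of tensors
    print(name, list(outputs.keys()))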
@@ -1,4 +1,4 @@
-from . import diffusers, timm
+from . import diffusers, timm, torchvision
 from .registry import model_zoo

 __all__ = ['model_zoo']
|
@@ -9,8 +9,13 @@ __all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
class ModelAttribute:
    """
    Attributes of a model.

    Args:
        has_control_flow (bool): Whether the model contains branching in its forward method.
        has_stochastic_depth_prob (bool): Whether the model uses a stochastic depth probability, as often seen in torchvision models.
    """
    has_control_flow: bool = False
    has_stochastic_depth_prob: bool = False


class ModelZooRegistry(dict):
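A hedged illustration (not in the diff) of how ModelAttribute is attached at registration time and inspected later. The `register` keywords mirror the torchvision registrations added below; the import path is inferred from the `from ..registry import ...` line in the new file, and the indexing behaviour is an assumption based on ModelZooRegistry being a dict of the tuples the updated test unpacks.

import torch
import torchvision.models as tm

from tests.kit.model_zoo.registry import ModelAttribute, model_zoo    # path inferred from the relative import below

# Register a model whose forward uses stochastic depth, mirroring the efficientnet_b0 entry added in this PR.
model_zoo.register(name='torchvision_efficientnet_b0',
                   model_fn=tm.efficientnet_b0,
                   data_gen_fn=lambda: dict(x=torch.rand(4, 3, 224, 224)),
                   output_transform_fn=lambda x: dict(output=x),
                   model_attribute=ModelAttribute(has_stochastic_depth_prob=True))

# Consumers can then branch on the attribute, e.g. to remove randomness before comparing outputs.
model_fn, data_gen_fn, output_transform_fn, attribute = model_zoo['torchvision_efficientnet_b0']
model = model_fn(stochastic_depth_prob=0) if attribute.has_stochastic_depth_prob else model_fn()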
|
@@ -0,0 +1 @@
from .torchvision import *
@@ -0,0 +1,131 @@
from collections import namedtuple

import torch
import torchvision
import torchvision.models as tm
from packaging import version

from ..registry import ModelAttribute, model_zoo

data_gen_fn = lambda: dict(x=torch.rand(4, 3, 224, 224))
output_transform_fn = lambda x: dict(output=x)

# special data gen fn
inception_v3_data_gen_fn = lambda: dict(x=torch.rand(4, 3, 299, 299))


# special model fn
def swin_s():
    from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer

    # adapted from torchvision.models.swin_transformer.swin_small
    weights = None
    weights = Swin_T_Weights.verify(weights)
    progress = True

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        stochastic_depth_prob=0,    # it is originally 0.2, but we set it to 0 to make it deterministic
        weights=weights,
        progress=progress,
    )


# special output transform fn
google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(
    x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
                                               for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(
    x, torchvision.models.InceptionOutputs) else dict(output=x)

model_zoo.register(name='torchvision_alexnet',
                   model_fn=tm.alexnet,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_densenet121',
                   model_fn=tm.densenet121,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_efficientnet_b0',
                   model_fn=tm.efficientnet_b0,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
model_zoo.register(name='torchvision_googlenet',
                   model_fn=tm.googlenet,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=google_net_output_transform_fn)
model_zoo.register(name='torchvision_inception_v3',
                   model_fn=tm.inception_v3,
                   data_gen_fn=inception_v3_data_gen_fn,
                   output_transform_fn=inception_v3_output_transform_fn)
model_zoo.register(name='torchvision_mobilenet_v2',
                   model_fn=tm.mobilenet_v2,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_mobilenet_v3_small',
                   model_fn=tm.mobilenet_v3_small,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_mnasnet0_5',
                   model_fn=tm.mnasnet0_5,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_resnet18',
                   model_fn=tm.resnet18,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_regnet_x_16gf',
                   model_fn=tm.regnet_x_16gf,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_resnext50_32x4d',
                   model_fn=tm.resnext50_32x4d,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_shufflenet_v2_x0_5',
                   model_fn=tm.shufflenet_v2_x0_5,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_squeezenet1_0',
                   model_fn=tm.squeezenet1_0,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)

model_zoo.register(name='torchvision_vgg11',
                   model_fn=tm.vgg11,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)
model_zoo.register(name='torchvision_wide_resnet50_2',
                   model_fn=tm.wide_resnet50_2,
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn)

if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
    model_zoo.register(name='torchvision_vit_b_16',
                       model_fn=tm.vit_b_16,
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='torchvision_convnext_base',
                       model_fn=tm.convnext_base,
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn,
                       model_attribute=ModelAttribute(has_stochastic_depth_prob=True))

if version.parse(torchvision.__version__) >= version.parse('0.13.0'):
    model_zoo.register(
        name='torchvision_swin_s',
        model_fn=swin_s,
        data_gen_fn=data_gen_fn,
        output_transform_fn=swin_s_output_output_transform_fn,
    )
    model_zoo.register(name='torchvision_efficientnet_v2_s',
                       model_fn=tm.efficientnet_v2_s,
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn,
                       model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
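The "special output transform" lambdas above exist because GoogLeNet and Inception v3 can return namedtuples rather than bare tensors, and Swin can return a tuple. A minimal, self-contained sketch (not part of the diff) of the GoogLeNet case, assuming only the public torchvision.models.GoogLeNetOutputs namedtuple:

import torch
import torchvision

# Same shape as google_net_output_transform_fn above: normalize both output forms to a dict
# so the tracer test can compare traced and eager results key by key.
transform = lambda x: dict(output=x.logits) if isinstance(
    x, torchvision.models.GoogLeNetOutputs) else dict(output=x)

namedtuple_out = torchvision.models.GoogLeNetOutputs(logits=torch.zeros(4, 1000),
                                                     aux_logits2=None,
                                                     aux_logits1=None)
print(transform(namedtuple_out).keys())          # dict_keys(['output'])
print(transform(torch.zeros(4, 1000)).keys())    # dict_keys(['output'])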
@@ -1,44 +1,43 @@
 import torch
 import torchvision
 import torchvision.models as tm
 from packaging import version

 from colossalai.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo


 def test_torchvision_models():
-    MODEL_LIST = [
-        tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
-        tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
-    ]
-
-    RANDOMIZED_MODELS = [tm.efficientnet_b0]
-
-    if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
-        MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
-        RANDOMIZED_MODELS.append(tm.convnext_small)
-
     torch.backends.cudnn.deterministic = True
+    tv_sub_registry = model_zoo.get_sub_registry('torchvision')

-    data = torch.rand(2, 3, 224, 224)
-    for model_cls in MODEL_LIST:
-        if model_cls in RANDOMIZED_MODELS:
-            # remove the impact of randomicity
-            model = model_cls(stochastic_depth_prob=0)
-        else:
-            model = model_cls()
+    for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
+        data = data_gen_fn()
+
+        if model_attribute is not None and model_attribute.has_stochastic_depth_prob:
+            model = model_fn(stochastic_depth_prob=0)
+        else:
+            model = model_fn()

         gm = symbolic_trace(model)

         model.eval()
         gm.eval()

-        with torch.no_grad():
-            fx_out = gm(data)
-            non_fx_out = model(data)
-        assert torch.allclose(
-            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+        try:
+            with torch.no_grad():
+                fx_out = gm(**data)
+                non_fx_out = model(**data)
+            transformed_out = output_transform_fn(fx_out)
+            transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+            assert len(transformed_out) == len(transformed_non_fx_out)
+
+            for key in transformed_out.keys():
+                fx_val = transformed_out[key]
+                non_fx_val = transformed_non_fx_out[key]
+                assert torch.allclose(
+                    fx_val,
+                    non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}'
+        except Exception as e:
+            print(name, e)


 if __name__ == '__main__':
     test_torchvision_models()
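Assuming the updated test file is collected by pytest in the usual way (its path is not shown in this diff), the registry-driven test can be run on its own with something like `pytest -k test_torchvision_models`, or by executing the file directly via the `__main__` guard above.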
|