ColossalAI/tests/test_models/test_vanilla_resnet/test_vanilla_resnet.py

99 lines
2.3 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
import torch
import torchvision.models as models
from colossalai.builder import build_model
# Number of output classes shared by every config below and by the
# torchvision reference models built in test_vanilla_resnet.
NUM_CLS = 10

# Builder configs for the VanillaResNet variants under test. Each one is
# passed to build_model() and compared against the torchvision model of the
# same depth.

# ResNet-18
RESNET18 = {
    'type': 'VanillaResNet',
    'block_type': 'ResNetBasicBlock',
    'layers': [2, 2, 2, 2],
    'num_cls': NUM_CLS,
}

# ResNet-34
RESNET34 = {
    'type': 'VanillaResNet',
    'block_type': 'ResNetBasicBlock',
    'layers': [3, 4, 6, 3],
    'num_cls': NUM_CLS,
}

# ResNet-50
RESNET50 = {
    'type': 'VanillaResNet',
    'block_type': 'ResNetBottleneck',
    'layers': [3, 4, 6, 3],
    'num_cls': NUM_CLS,
}

# ResNet-101
RESNET101 = {
    'type': 'VanillaResNet',
    'block_type': 'ResNetBottleneck',
    'layers': [3, 4, 23, 3],
    'num_cls': NUM_CLS,
}

# ResNet-152
RESNET152 = {
    'type': 'VanillaResNet',
    'block_type': 'ResNetBottleneck',
    'layers': [3, 8, 36, 3],
    'num_cls': NUM_CLS,
}
def compare_model(data, colossal_model, torchvision_model):
    """Assert that both models produce outputs of the same shape for ``data``.

    Args:
        data: Input passed unchanged to both models.
        colossal_model: Callable whose result is indexed with ``[0]`` before
            its shape is read.
        torchvision_model: Callable whose result's shape is read directly.

    Raises:
        AssertionError: If the two shapes differ; the message lists both
            shapes, colossal first.
    """
    # Only the output shapes are inspected, so run the forward passes without
    # autograd bookkeeping to avoid retaining graphs and activations.
    with torch.no_grad():
        colossal_output = colossal_model(data)
        torchvision_output = torchvision_model(data)

    colossal_shape = colossal_output[0].shape
    torchvision_shape = torchvision_output.shape
    assert colossal_shape == torchvision_shape, f'{colossal_shape}, {torchvision_shape}'
@pytest.mark.cpu
def test_vanilla_resnet():
    """Check each colossal ResNet variant against its torchvision counterpart."""
    # One shared random input batch is fed to every model pair.
    batch = torch.randn((2, 3, 224, 224))

    # (colossal builder config, torchvision constructor), shallowest first.
    variants = (
        (RESNET18, models.resnet18),
        (RESNET34, models.resnet34),
        (RESNET50, models.resnet50),
        (RESNET101, models.resnet101),
        (RESNET152, models.resnet152),
    )

    for cfg, make_reference in variants:
        candidate = build_model(cfg)
        candidate.build_from_cfg()
        reference = make_reference(num_classes=NUM_CLS)
        compare_model(batch, candidate, reference)
if __name__ == '__main__':
    # Allow running the comparison directly as a script, without pytest.
    test_vanilla_resnet()