mirror of https://github.com/hpcaitech/ColossalAI
99 lines
2.3 KiB
Python
99 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import pytest
|
|
import torch
|
|
import torchvision.models as models
|
|
|
|
from colossalai.builder import build_model
|
|
|
|
NUM_CLS = 10
|
|
|
|
RESNET18 = dict(
|
|
type='VanillaResNet',
|
|
block_type='ResNetBasicBlock',
|
|
layers=[2, 2, 2, 2],
|
|
num_cls=NUM_CLS
|
|
)
|
|
|
|
RESNET34 = dict(
|
|
type='VanillaResNet',
|
|
block_type='ResNetBasicBlock',
|
|
layers=[3, 4, 6, 3],
|
|
num_cls=NUM_CLS
|
|
)
|
|
|
|
RESNET50 = dict(
|
|
type='VanillaResNet',
|
|
block_type='ResNetBottleneck',
|
|
layers=[3, 4, 6, 3],
|
|
num_cls=NUM_CLS
|
|
)
|
|
|
|
RESNET101 = dict(
|
|
type='VanillaResNet',
|
|
block_type='ResNetBottleneck',
|
|
layers=[3, 4, 23, 3],
|
|
num_cls=NUM_CLS
|
|
)
|
|
|
|
RESNET152 = dict(
|
|
type='VanillaResNet',
|
|
block_type='ResNetBottleneck',
|
|
layers=[3, 8, 36, 3],
|
|
num_cls=NUM_CLS
|
|
)
|
|
|
|
|
|
def compare_model(data, colossal_model, torchvision_model):
|
|
colossal_output = colossal_model(data)
|
|
torchvision_output = torchvision_model(data)
|
|
assert colossal_output[
|
|
0].shape == torchvision_output.shape, f'{colossal_output[0].shape}, {torchvision_output.shape}'
|
|
|
|
|
|
@pytest.mark.cpu
|
|
def test_vanilla_resnet():
|
|
"""Compare colossal resnet with torchvision resnet"""
|
|
# data
|
|
x = torch.randn((2, 3, 224, 224))
|
|
|
|
# resnet 18
|
|
col_resnet18 = build_model(RESNET18)
|
|
col_resnet18.build_from_cfg()
|
|
torchvision_resnet18 = models.resnet18(num_classes=NUM_CLS)
|
|
|
|
compare_model(x, col_resnet18, torchvision_resnet18)
|
|
|
|
# resnet 34
|
|
col_resnet34 = build_model(RESNET34)
|
|
col_resnet34.build_from_cfg()
|
|
torchvision_resnet34 = models.resnet34(num_classes=NUM_CLS)
|
|
|
|
compare_model(x, col_resnet34, torchvision_resnet34)
|
|
|
|
# resnet 50
|
|
col_resnet50 = build_model(RESNET50)
|
|
col_resnet50.build_from_cfg()
|
|
torchvision_resnet50 = models.resnet50(num_classes=NUM_CLS)
|
|
|
|
compare_model(x, col_resnet50, torchvision_resnet50)
|
|
|
|
# resnet 101
|
|
col_resnet101 = build_model(RESNET101)
|
|
col_resnet101.build_from_cfg()
|
|
torchvision_resnet101 = models.resnet101(num_classes=NUM_CLS)
|
|
|
|
compare_model(x, col_resnet101, torchvision_resnet101)
|
|
|
|
# # resnet 152
|
|
col_resnet152 = build_model(RESNET152)
|
|
col_resnet152.build_from_cfg()
|
|
torchvision_resnet152 = models.resnet152(num_classes=NUM_CLS)
|
|
|
|
compare_model(x, col_resnet152, torchvision_resnet152)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
test_vanilla_resnet()
|