[Analyzer] fix analyzer tests (#3197)

pull/3199/head
YuliangLiu0306 2023-03-22 13:38:11 +08:00 committed by GitHub
parent f57d34958b
commit 019a847432
7 changed files with 60 additions and 100 deletions

File 1 of 7 (filename not captured): the tracer test for bias addition and activation checkpointing (SiuModel / AddmmModel).

@@ -3,6 +3,8 @@ import torch
+from packaging import version
 from torch.utils.checkpoint import checkpoint
+from colossalai.testing.utils import parameterize

 try:
     from colossalai._analyzer.fx import symbolic_trace
 except:
     pass
@@ -56,9 +58,13 @@ class SiuModel(torch.nn.Module):
         self.linear = LinearModel(3, 3, bias)
         self.conv = ConvModel(3, 6, 3, bias)

-    def forward(self, x, select=0):
+    def forward(self, x, select=torch.Tensor([0])):
         x = self.linear(x)
-        x = checkpoint(self.conv, x, select)
+        if select:
+            x = checkpoint(self.conv, x, 0)
+        else:
+            x = checkpoint(self.conv, x, 1)
         return x
@@ -75,10 +81,10 @@ class AddmmModel(torch.nn.Module):
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("bias_addition_split", [True, False])
-@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
-@pytest.mark.parametrize("select", [0, 1])
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize("bias", [True, False])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
+@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
 def test_siu_model(bias, bias_addition_split, shape, select):
     model = SiuModel(bias=bias)
     x = torch.rand(shape)
@@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
                         concrete_args={'select': select},
                         trace_act_ckpt=True,
                         bias_addition_split=bias_addition_split)
-    assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
+    assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
     if bias and bias_addition_split:
         assert '+' in gm.code, 'bias addition should be split!'
     else:
         assert '+' not in gm.code, 'bias addition should not be split!'

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize("alpha", [1, 2])
-@pytest.mark.parametrize("beta", [1, 2])
-@pytest.mark.parametrize("bias_addition_split", [True, False])
-@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize("alpha", [1, 2])
+@parameterize("beta", [1, 2])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3), (5, 5)])
 def test_addmm_model(alpha, beta, bias_addition_split, shape):
     model = AddmmModel(alpha=alpha, beta=beta)
     x = torch.rand(shape)
@@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):

 if __name__ == '__main__':
-    test_siu_model(True, True, (3, 3, 3))
+    test_siu_model()
+    test_addmm_model()

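Note on the hunks above: `select` is now a tensor and is fed to the tracer through `concrete_args`, so the `if select:` branch is resolved at trace time and the traced module no longer takes `select`; that is why the assertion compares `model(x, select)` against `gm(x)`. A minimal sketch of such a traced call, assuming ColossalAI's analyzer is importable; the stand-in module and the `meta_args` keyword are assumptions, since the opening line of the original `symbolic_trace(...)` call is not visible in this diff:

import torch
from torch.utils.checkpoint import checkpoint

from colossalai._analyzer.fx import symbolic_trace

class Toy(torch.nn.Module):
    # hypothetical stand-in for SiuModel, for illustration only

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x, select=torch.Tensor([0])):
        x = self.linear(x)
        if select:    # resolved at trace time via concrete_args
            return checkpoint(self.linear, x)
        return x

model = Toy()
x = torch.rand(3, 3, 3)
select = torch.Tensor([1])

gm = symbolic_trace(model,
                    meta_args={'x': x},                  # assumed keyword, see note above
                    concrete_args={'select': select},    # bakes the chosen branch in
                    trace_act_ckpt=True)                 # trace through checkpoint regions
assert torch.allclose(model(x, select), gm(x))           # gm takes x only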
File 2 of 7 (filename not captured): the shape-propagation test.

@@ -1,16 +1,17 @@
 import pytest
 import timm.models as tmm
 import torch
 import torchvision.models as tm
-from .zoo import tm_models, tmm_models
+from packaging import version
+
+from colossalai.testing.utils import parameterize
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

 try:
     from colossalai._analyzer._subclasses import MetaTensorMode
     from colossalai._analyzer.fx import symbolic_trace
     from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
     from colossalai._analyzer.fx.symbolic_profile import register_shape_impl

     @register_shape_impl(torch.nn.functional.linear)
     def linear_impl(*args, **kwargs):
         assert True
@@ -23,15 +24,15 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
     for node in gm.graph.nodes:
         assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
         if node.op in [
-                # 'call_module',    # can apply to params
-                # 'call_function',    # can apply to params
-                # 'call_method',    # can apply to params
+                'call_module',    # can apply to params
+                'call_function',    # can apply to params
+                'call_method',    # can apply to params
         ]:
-            assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'
+            assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.'

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tm_models)
 def test_torchvision_shape_prop(m):
     with MetaTensorMode():
         model = m()
@@ -44,8 +45,8 @@ def test_torchvision_shape_prop(m):
     _check_gm_validity(gm)

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tmm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tmm_models)
 def test_timm_shape_prop(m):
     with MetaTensorMode():
         model = m()
@@ -53,11 +54,12 @@ def test_timm_shape_prop(m):
     meta_args = {
         "x": data,
     }
     gm = symbolic_trace(model, meta_args=meta_args)
     shape_prop_pass(gm, data)
     _check_gm_validity(gm)

 if __name__ == "__main__":
-    test_torchvision_shape_prop(tm.resnet18)
-    test_timm_shape_prop(tmm.vgg11)
+    test_torchvision_shape_prop()
+    test_timm_shape_prop()

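Note: the pipeline these shape-propagation tests exercise is: build the model under `MetaTensorMode` (parameters materialize as meta tensors, so no real memory is allocated), trace with `meta_args`, run `shape_prop_pass`, then read shapes off `node.meta['info']` exactly as `_check_gm_validity` does. A minimal sketch under those assumptions, mirroring the test bodies above:

import torch
import torchvision.models as tm

from colossalai._analyzer._subclasses import MetaTensorMode
from colossalai._analyzer.fx import symbolic_trace
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass

with MetaTensorMode():
    # construction under the mode allocates meta tensors only
    model = tm.resnet18()
    data = torch.rand(2, 3, 224, 224)    # zoo input shape: (batch_size, 3, 224, 224)

gm = symbolic_trace(model, meta_args={'x': data})
shape_prop_pass(gm, data)

for node in gm.graph.nodes:
    # every node now carries propagated output metadata
    print(node.op, node.target, node.meta['info'].outputs)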
File 3 of 7 (filename not captured): the symbolic-profile test.

@@ -1,8 +1,10 @@
 import pytest
 import timm.models as tmm
 import torch
 import torchvision.models as tm
-from .zoo import tm_models, tmm_models
+from packaging import version
+from colossalai.testing.utils import parameterize
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

 try:
     from colossalai._analyzer._subclasses import MetaTensorMode
@@ -16,8 +18,8 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
         assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tm_models)
 def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
     with MetaTensorMode():
         model = m()
@@ -30,8 +32,8 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
     _check_gm_validity(gm)

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize('m', tmm_models)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tmm_models)
 def test_timm_profile(m, verbose=False, bias_addition_split=False):
     with MetaTensorMode():
         model = m()
@@ -45,5 +47,5 @@ def test_timm_profile(m, verbose=False, bias_addition_split=False):

 if __name__ == "__main__":
-    test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
-    test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)
+    test_torchvision_profile()
+    test_timm_profile()

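Note: the recurring `pytest.mark.parametrize` to `parameterize` swap in these files is what lets the `__main__` blocks call the tests with no arguments: ColossalAI's `parameterize` expands the argument grid itself when the function is called directly, instead of relying on the pytest collector. Roughly, such a decorator behaves like the following hypothetical reimplementation (for illustration only, not ColossalAI's actual code):

import functools

def parameterize(arg_name, values):

    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # expand the grid on a direct call
            for v in values:
                fn(*args, **{**kwargs, arg_name: v})

        return wrapper

    return decorator

@parameterize('bias', [True, False])
@parameterize('shape', [(3, 3), (5, 5)])
def demo(bias, shape):
    print(bias, shape)

demo()    # runs all four combinations -- hence the bare test_*() calls above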
File 4 of 7 (filename not captured): the model zoo list (tmm_models) that the tests now import.

@@ -33,18 +33,18 @@ tmm_models = [
     tmm.dm_nfnet_f0,
     tmm.eca_nfnet_l0,
     tmm.efficientformer_l1,
-    tmm.ese_vovnet19b_dw,
+    # tmm.ese_vovnet19b_dw,
     tmm.gmixer_12_224,
     tmm.gmlp_b16_224,
-    tmm.hardcorenas_a,
+    # tmm.hardcorenas_a,
     tmm.hrnet_w18_small,
     tmm.inception_v3,
     tmm.mixer_b16_224,
     tmm.nf_ecaresnet101,
     tmm.nf_regnet_b0,
     # tmm.pit_b_224,    # pretrained only
-    tmm.regnetv_040,
-    tmm.skresnet18,
+    # tmm.regnetv_040,
+    # tmm.skresnet18,
     # tmm.swin_base_patch4_window7_224,    # fx bad case
     # tmm.tnt_b_patch16_224,    # bad case
     tmm.vgg11,

File 5 of 7 (filename not captured): the FLOP-count test.

@@ -1,9 +1,10 @@
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as tm
-from .zoo import tm_models, tmm_models
+from packaging import version
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

 try:
     from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
@@ -11,7 +12,7 @@ except:
     pass

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize('m', tm_models + tmm_models)
 def test_flop_count_module(m):
     x = torch.rand(2, 3, 224, 224)
@@ -37,7 +38,7 @@ odd_cases = [
 ]

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize('func, args, kwargs', odd_cases)
 def test_flop_count_function(func, args, kwargs):
     rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
@@ -46,5 +47,5 @@ def test_flop_count_function(func, args, kwargs):

 if __name__ == '__main__':
-    test_flop_count_module(tm.resnet18, torch.rand(2, 3, 224, 224))
+    test_flop_count_module(tm.resnet18)
     test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})

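Note: `flop_count` as used above returns separate forward and backward FLOP tallies. A minimal sketch of a call, assuming ColossalAI is installed; it mirrors the `F.relu` entry in `odd_cases`:

import torch
import torch.nn.functional as F

from colossalai._analyzer._subclasses import flop_count

x = torch.rand(2, 3, 224, 224, requires_grad=True)
rs_fwd, rs_bwd = flop_count(F.relu, x, inplace=True, verbose=True)
print(f'forward FLOPs: {rs_fwd}, backward FLOPs: {rs_bwd}')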
File 6 of 7 (filename not captured): the meta-tensor-mode test.

@@ -1,12 +1,13 @@
 import pytest
 import torch
 import torch.distributed as dist
 import torchvision.models as tm
+from packaging import version

 try:
     from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
 except:
     pass

-from .zoo import tm_models, tmm_models
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

 def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
@@ -28,7 +29,7 @@ def run_and_compare(model):
     compare_all(x.grad, meta_x.grad)

-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize('m', tm_models + tmm_models)
 def test_meta_mode_shape(m):
     run_and_compare(m())

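Note: `run_and_compare` asserts that a forward and backward pass under `MetaTensorMode` reproduces the shapes of a real run (down to `x.grad` versus `meta_x.grad`) without allocating real storage. A minimal sketch of the forward half, under the same assumptions:

import torch
import torchvision.models as tm

from colossalai._analyzer._subclasses import MetaTensorMode

real_out = tm.resnet18()(torch.rand(2, 3, 224, 224))

with MetaTensorMode():
    meta_model = tm.resnet18()    # parameters are meta tensors
    meta_out = meta_model(torch.rand(2, 3, 224, 224))

assert real_out.shape == meta_out.shape    # shapes agree; the meta run holds no data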
File 7 of 7 (filename not captured, deleted): the superseded copy of the model zoo list.

@@ -1,53 +0,0 @@
-import timm.models as tmm
-import torchvision.models as tm
-
-# input shape: (batch_size, 3, 224, 224)
-tm_models = [
-    tm.alexnet,
-    tm.convnext_base,
-    tm.densenet121,
-    # tm.efficientnet_v2_s,
-    # tm.googlenet,    # output bad case
-    # tm.inception_v3,    # bad case
-    tm.mobilenet_v2,
-    tm.mobilenet_v3_small,
-    tm.mnasnet0_5,
-    tm.resnet18,
-    tm.regnet_x_16gf,
-    tm.resnext50_32x4d,
-    tm.shufflenet_v2_x0_5,
-    tm.squeezenet1_0,
-    # tm.swin_s,    # fx bad case
-    tm.vgg11,
-    tm.vit_b_16,
-    tm.wide_resnet50_2,
-]
-
-tmm_models = [
-    tmm.beit_base_patch16_224,
-    tmm.beitv2_base_patch16_224,
-    tmm.cait_s24_224,
-    tmm.coat_lite_mini,
-    tmm.convit_base,
-    tmm.deit3_base_patch16_224,
-    tmm.dm_nfnet_f0,
-    tmm.eca_nfnet_l0,
-    tmm.efficientformer_l1,
-    tmm.ese_vovnet19b_dw,
-    tmm.gmixer_12_224,
-    tmm.gmlp_b16_224,
-    tmm.hardcorenas_a,
-    tmm.hrnet_w18_small,
-    tmm.inception_v3,
-    tmm.mixer_b16_224,
-    tmm.nf_ecaresnet101,
-    tmm.nf_regnet_b0,
-    # tmm.pit_b_224,    # pretrained only
-    tmm.regnetv_040,
-    tmm.skresnet18,
-    # tmm.swin_base_patch4_window7_224,    # fx bad case
-    # tmm.tnt_b_patch16_224,    # bad case
-    tmm.vgg11,
-    tmm.vit_base_patch16_18x2_224,
-    tmm.wide_resnet50_2,
-]