mirror of https://github.com/hpcaitech/ColossalAI
[Analyzer] fix analyzer tests (#3197)
parent f57d34958b
commit 019a847432
@@ -3,6 +3,8 @@ import torch
 from packaging import version
 from torch.utils.checkpoint import checkpoint

+from colossalai.testing.utils import parameterize
+
 try:
     from colossalai._analyzer.fx import symbolic_trace
 except:
@@ -56,9 +58,13 @@ class SiuModel(torch.nn.Module):
         self.linear = LinearModel(3, 3, bias)
         self.conv = ConvModel(3, 6, 3, bias)

-    def forward(self, x, select=0):
+    def forward(self, x, select=torch.Tensor([0])):
         x = self.linear(x)
-        x = checkpoint(self.conv, x, select)
+        if select:
+            x = checkpoint(self.conv, x, 0)
+        else:
+            x = checkpoint(self.conv, x, 1)
         return x

@ -75,10 +81,10 @@ class AddmmModel(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize("bias", [True, False])
|
@parameterize("bias", [True, False])
|
||||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
@parameterize("bias_addition_split", [True, False])
|
||||||
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||||
@pytest.mark.parametrize("select", [0, 1])
|
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
|
||||||
def test_siu_model(bias, bias_addition_split, shape, select):
|
def test_siu_model(bias, bias_addition_split, shape, select):
|
||||||
model = SiuModel(bias=bias)
|
model = SiuModel(bias=bias)
|
||||||
x = torch.rand(shape)
|
x = torch.rand(shape)
|
||||||
|
@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
|
||||||
concrete_args={'select': select},
|
concrete_args={'select': select},
|
||||||
trace_act_ckpt=True,
|
trace_act_ckpt=True,
|
||||||
bias_addition_split=bias_addition_split)
|
bias_addition_split=bias_addition_split)
|
||||||
assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
|
assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
|
||||||
if bias and bias_addition_split:
|
if bias and bias_addition_split:
|
||||||
assert '+' in gm.code, 'bias addition should be split!'
|
assert '+' in gm.code, 'bias addition should be split!'
|
||||||
else:
|
else:
|
||||||
assert '+' not in gm.code, 'bias addition should not be split!'
|
assert '+' not in gm.code, 'bias addition should not be split!'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize("alpha", [1, 2])
|
@parameterize("alpha", [1, 2])
|
||||||
@pytest.mark.parametrize("beta", [1, 2])
|
@parameterize("beta", [1, 2])
|
||||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
@parameterize("bias_addition_split", [True, False])
|
||||||
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
|
@parameterize("shape", [(3, 3), (5, 5)])
|
||||||
def test_addmm_model(alpha, beta, bias_addition_split, shape):
|
def test_addmm_model(alpha, beta, bias_addition_split, shape):
|
||||||
model = AddmmModel(alpha=alpha, beta=beta)
|
model = AddmmModel(alpha=alpha, beta=beta)
|
||||||
x = torch.rand(shape)
|
x = torch.rand(shape)
|
||||||
|
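Note on the traced call: because `select` is now supplied through `concrete_args`, the tracer specializes the `if select:` branch at trace time and bakes the choice into the graph, so the traced module no longer accepts `select`; that is why `model(x, select)` is compared against `gm(x)`. A minimal sketch of that flow, reusing only the keywords visible in this diff (the `meta_args` key is an assumption, mirroring the shape-prop tests below):

import torch
from colossalai._analyzer.fx import symbolic_trace

# SiuModel is the test model defined earlier in this file
model = SiuModel(bias=True)
x = torch.rand(3, 3, 3)
select = torch.Tensor([0])
gm = symbolic_trace(model,
                    meta_args={'x': x},                # assumed key
                    concrete_args={'select': select},  # branch resolved at trace time
                    trace_act_ckpt=True,
                    bias_addition_split=False)
assert torch.allclose(model(x, select), gm(x))         # traced module takes only x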
@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_siu_model(True, True, (3, 3, 3))
|
test_siu_model()
|
||||||
|
test_addmm_model()
|
||||||
|
|
|
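The switch from `@pytest.mark.parametrize` to ColossalAI's `@parameterize` is also what makes the new no-argument calls under `__main__` work: `parameterize` does not expand tests at pytest collection time, it wraps the function and iterates the values itself. A rough sketch of that behavior, assumed from its use in this diff rather than taken from the ColossalAI source:

from functools import wraps

def parameterize(name, values):
    # Run the decorated function once per value of `name` (sketch).
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})
        return wrapper
    return decorator

Stacked decorators then execute the full Cartesian product, so `test_siu_model()` runs every combination of bias, bias_addition_split, shape, and select.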
@ -1,16 +1,17 @@
|
||||||
import pytest
|
import pytest
|
||||||
import timm.models as tmm
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
from .zoo import tm_models, tmm_models
|
from packaging import version
|
||||||
|
|
||||||
|
from colossalai.testing.utils import parameterize
|
||||||
|
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from colossalai._analyzer._subclasses import MetaTensorMode
|
from colossalai._analyzer._subclasses import MetaTensorMode
|
||||||
from colossalai._analyzer.fx import symbolic_trace
|
from colossalai._analyzer.fx import symbolic_trace
|
||||||
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
||||||
from colossalai._analyzer.fx.symbolic_profile import register_shape_impl
|
from colossalai._analyzer.fx.symbolic_profile import register_shape_impl
|
||||||
|
|
||||||
|
|
||||||
@register_shape_impl(torch.nn.functional.linear)
|
@register_shape_impl(torch.nn.functional.linear)
|
||||||
def linear_impl(*args, **kwargs):
|
def linear_impl(*args, **kwargs):
|
||||||
assert True
|
assert True
|
||||||
|
@ -23,15 +24,15 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
|
assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
|
||||||
if node.op in [
|
if node.op in [
|
||||||
# 'call_module', # can apply to params
|
'call_module', # can apply to params
|
||||||
# 'call_function', # can apply to params
|
'call_function', # can apply to params
|
||||||
# 'call_method', # can apply to params
|
'call_method', # can apply to params
|
||||||
]:
|
]:
|
||||||
assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'
|
assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('m', tm_models)
|
@parameterize('m', tm_models)
|
||||||
def test_torchvision_shape_prop(m):
|
def test_torchvision_shape_prop(m):
|
||||||
with MetaTensorMode():
|
with MetaTensorMode():
|
||||||
model = m()
|
model = m()
|
||||||
|
@ -44,8 +45,8 @@ def test_torchvision_shape_prop(m):
|
||||||
_check_gm_validity(gm)
|
_check_gm_validity(gm)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('m', tmm_models)
|
@parameterize('m', tmm_models)
|
||||||
def test_timm_shape_prop(m):
|
def test_timm_shape_prop(m):
|
||||||
with MetaTensorMode():
|
with MetaTensorMode():
|
||||||
model = m()
|
model = m()
|
||||||
|
@ -53,11 +54,12 @@ def test_timm_shape_prop(m):
|
||||||
meta_args = {
|
meta_args = {
|
||||||
"x": data,
|
"x": data,
|
||||||
}
|
}
|
||||||
|
|
||||||
gm = symbolic_trace(model, meta_args=meta_args)
|
gm = symbolic_trace(model, meta_args=meta_args)
|
||||||
shape_prop_pass(gm, data)
|
shape_prop_pass(gm, data)
|
||||||
_check_gm_validity(gm)
|
_check_gm_validity(gm)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_torchvision_shape_prop(tm.resnet18)
|
test_torchvision_shape_prop()
|
||||||
test_timm_shape_prop(tmm.vgg11)
|
test_timm_shape_prop()
|
||||||
|
|
|
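For orientation, the pipeline this file tests, condensed into one hedged sketch (all calls appear in the hunks above; the input shape and the `meta_args` key follow the timm test, and the model choice is arbitrary):

import torch
import torchvision.models as tm
from colossalai._analyzer._subclasses import MetaTensorMode
from colossalai._analyzer.fx import symbolic_trace
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass

with MetaTensorMode():                  # build the model on meta tensors
    model = tm.resnet18()
data = torch.rand(2, 3, 224, 224)
gm = symbolic_trace(model, meta_args={'x': data})
shape_prop_pass(gm, data)               # annotates node.meta['info'] with shapes
for node in gm.graph.nodes:
    print(node, node.meta['info'].outputs)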
@ -1,8 +1,10 @@
|
||||||
import pytest
|
import pytest
|
||||||
import timm.models as tmm
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
from .zoo import tm_models, tmm_models
|
from packaging import version
|
||||||
|
|
||||||
|
from colossalai.testing.utils import parameterize
|
||||||
|
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from colossalai._analyzer._subclasses import MetaTensorMode
|
from colossalai._analyzer._subclasses import MetaTensorMode
|
||||||
|
@ -16,8 +18,8 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
|
||||||
assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
|
assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('m', tm_models)
|
@parameterize('m', tm_models)
|
||||||
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
|
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
|
||||||
with MetaTensorMode():
|
with MetaTensorMode():
|
||||||
model = m()
|
model = m()
|
||||||
|
@ -30,8 +32,8 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
|
||||||
_check_gm_validity(gm)
|
_check_gm_validity(gm)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('m', tmm_models)
|
@parameterize('m', tmm_models)
|
||||||
def test_timm_profile(m, verbose=False, bias_addition_split=False):
|
def test_timm_profile(m, verbose=False, bias_addition_split=False):
|
||||||
with MetaTensorMode():
|
with MetaTensorMode():
|
||||||
model = m()
|
model = m()
|
||||||
|
@ -45,5 +47,5 @@ def test_timm_profile(m, verbose=False, bias_addition_split=False):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
|
test_torchvision_profile()
|
||||||
test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)
|
test_timm_profile()
|
||||||
|
|
|
@ -33,18 +33,18 @@ tmm_models = [
|
||||||
tmm.dm_nfnet_f0,
|
tmm.dm_nfnet_f0,
|
||||||
tmm.eca_nfnet_l0,
|
tmm.eca_nfnet_l0,
|
||||||
tmm.efficientformer_l1,
|
tmm.efficientformer_l1,
|
||||||
tmm.ese_vovnet19b_dw,
|
# tmm.ese_vovnet19b_dw,
|
||||||
tmm.gmixer_12_224,
|
tmm.gmixer_12_224,
|
||||||
tmm.gmlp_b16_224,
|
tmm.gmlp_b16_224,
|
||||||
tmm.hardcorenas_a,
|
# tmm.hardcorenas_a,
|
||||||
tmm.hrnet_w18_small,
|
tmm.hrnet_w18_small,
|
||||||
tmm.inception_v3,
|
tmm.inception_v3,
|
||||||
tmm.mixer_b16_224,
|
tmm.mixer_b16_224,
|
||||||
tmm.nf_ecaresnet101,
|
tmm.nf_ecaresnet101,
|
||||||
tmm.nf_regnet_b0,
|
tmm.nf_regnet_b0,
|
||||||
# tmm.pit_b_224, # pretrained only
|
# tmm.pit_b_224, # pretrained only
|
||||||
tmm.regnetv_040,
|
# tmm.regnetv_040,
|
||||||
tmm.skresnet18,
|
# tmm.skresnet18,
|
||||||
# tmm.swin_base_patch4_window7_224, # fx bad case
|
# tmm.swin_base_patch4_window7_224, # fx bad case
|
||||||
# tmm.tnt_b_patch16_224, # bad case
|
# tmm.tnt_b_patch16_224, # bad case
|
||||||
tmm.vgg11,
|
tmm.vgg11,
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
from .zoo import tm_models, tmm_models
|
from packaging import version
|
||||||
|
|
||||||
|
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
|
from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
|
||||||
|
@ -11,7 +12,7 @@ except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||||
def test_flop_count_module(m):
|
def test_flop_count_module(m):
|
||||||
x = torch.rand(2, 3, 224, 224)
|
x = torch.rand(2, 3, 224, 224)
|
||||||
|
@ -37,7 +38,7 @@ odd_cases = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
|
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
|
||||||
def test_flop_count_function(func, args, kwargs):
|
def test_flop_count_function(func, args, kwargs):
|
||||||
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
|
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
|
||||||
|
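A hedged usage example assembled from the signature exercised above (`flop_count(func, *args, **kwargs, verbose=True)` returns forward and backward FLOP tallies; `F.relu` with `inplace=True` is one of the `odd_cases` this file tests):

import torch
import torch.nn.functional as F
from colossalai._analyzer._subclasses import flop_count

x = torch.rand(2, 3, 224, 224, requires_grad=True)
rs_fwd, rs_bwd = flop_count(F.relu, x, inplace=True, verbose=True)
print(f'fwd flops: {rs_fwd}, bwd flops: {rs_bwd}')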
@ -46,5 +47,5 @@ def test_flop_count_function(func, args, kwargs):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_flop_count_module(tm.resnet18, torch.rand(2, 3, 224, 224))
|
test_flop_count_module(tm.resnet18)
|
||||||
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
|
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
from .zoo import tm_models, tmm_models
|
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||||
|
|
||||||
|
|
||||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
|
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
|
||||||
|
@ -28,7 +29,7 @@ def run_and_compare(model):
|
||||||
compare_all(x.grad, meta_x.grad)
|
compare_all(x.grad, meta_x.grad)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||||
def test_meta_mode_shape(m):
|
def test_meta_mode_shape(m):
|
||||||
run_and_compare(m())
|
run_and_compare(m())
|
||||||
|
|
|
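These meta-mode tests rely on `MetaTensorMode` producing tensors that carry shape, stride, and dtype without allocating real storage; `compare_all` then checks that metadata against a real run. A minimal, assumption-labeled sketch:

import torch
from colossalai._analyzer._subclasses import MetaTensorMode

with MetaTensorMode():
    t = torch.rand(8, 16)    # assumption: created as a meta tensor, no real data
print(t.shape, t.stride(), t.dtype)    # metadata is still available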
@ -1,53 +0,0 @@
|
||||||
import timm.models as tmm
|
|
||||||
import torchvision.models as tm
|
|
||||||
|
|
||||||
# input shape: (batch_size, 3, 224, 224)
|
|
||||||
tm_models = [
|
|
||||||
tm.alexnet,
|
|
||||||
tm.convnext_base,
|
|
||||||
tm.densenet121,
|
|
||||||
# tm.efficientnet_v2_s,
|
|
||||||
# tm.googlenet, # output bad case
|
|
||||||
# tm.inception_v3, # bad case
|
|
||||||
tm.mobilenet_v2,
|
|
||||||
tm.mobilenet_v3_small,
|
|
||||||
tm.mnasnet0_5,
|
|
||||||
tm.resnet18,
|
|
||||||
tm.regnet_x_16gf,
|
|
||||||
tm.resnext50_32x4d,
|
|
||||||
tm.shufflenet_v2_x0_5,
|
|
||||||
tm.squeezenet1_0,
|
|
||||||
# tm.swin_s, # fx bad case
|
|
||||||
tm.vgg11,
|
|
||||||
tm.vit_b_16,
|
|
||||||
tm.wide_resnet50_2,
|
|
||||||
]
|
|
||||||
|
|
||||||
tmm_models = [
|
|
||||||
tmm.beit_base_patch16_224,
|
|
||||||
tmm.beitv2_base_patch16_224,
|
|
||||||
tmm.cait_s24_224,
|
|
||||||
tmm.coat_lite_mini,
|
|
||||||
tmm.convit_base,
|
|
||||||
tmm.deit3_base_patch16_224,
|
|
||||||
tmm.dm_nfnet_f0,
|
|
||||||
tmm.eca_nfnet_l0,
|
|
||||||
tmm.efficientformer_l1,
|
|
||||||
tmm.ese_vovnet19b_dw,
|
|
||||||
tmm.gmixer_12_224,
|
|
||||||
tmm.gmlp_b16_224,
|
|
||||||
tmm.hardcorenas_a,
|
|
||||||
tmm.hrnet_w18_small,
|
|
||||||
tmm.inception_v3,
|
|
||||||
tmm.mixer_b16_224,
|
|
||||||
tmm.nf_ecaresnet101,
|
|
||||||
tmm.nf_regnet_b0,
|
|
||||||
# tmm.pit_b_224, # pretrained only
|
|
||||||
tmm.regnetv_040,
|
|
||||||
tmm.skresnet18,
|
|
||||||
# tmm.swin_base_patch4_window7_224, # fx bad case
|
|
||||||
# tmm.tnt_b_patch16_224, # bad case
|
|
||||||
tmm.vgg11,
|
|
||||||
tmm.vit_base_patch16_18x2_224,
|
|
||||||
tmm.wide_resnet50_2,
|
|
||||||
]
|
|