"""Tests for ``flop_count`` from ``colossalai._analyzer._subclasses``.

Every test asserts that both the forward and the backward FLOP counts
returned by ``flop_count`` are strictly positive.
"""
import pytest
import torch
import torch.nn.functional as F
import torchvision.models as tm
from packaging import version

from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
    from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except Exception:
    # Best-effort import: the analyzer may fail to import on older torch,
    # where every test below is skipped anyway. Catching ``Exception`` instead
    # of using a bare ``except`` lets KeyboardInterrupt / SystemExit propagate.
    pass

# Shared skip marker: the tests below only run on torch >= 1.12.0.
requires_torch_1_12 = pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse('1.12.0'),
    reason='torch version < 1.12.0',
)


@requires_torch_1_12
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_flop_count_module(m):
    """Check that ``flop_count`` reports positive fwd/bwd FLOPs for a model.

    Args:
        m: a zero-argument model constructor taken from the test zoo
            (presumably torchvision / timm model builders — the test only
            calls ``m()`` and reads ``m.__name__``).
    """
    x = torch.rand(2, 3, 224, 224)
    # Construct the model under MetaTensorMode to save time for testing
    # (presumably parameters are created without real storage — confirm).
    with MetaTensorMode():
        module = m()
    rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
    assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'


# Functional ops with unusual call patterns, as ``(func, args, kwargs)``:
# an in-place activation, a pooling op with non-default stride/padding/dilation,
# and a three-input op whose first argument is a boolean mask.
odd_cases = [
    (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True}),
    (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
        'kernel_size': 3,
        'stride': 2,
        'padding': 1,
        'dilation': 2,
    }),
    (torch.where, (
        torch.rand(2, 3, 224, 224) > 0.5,
        torch.rand(2, 3, 224, 224, requires_grad=True),
        torch.rand(2, 3, 224, 224, requires_grad=True),
    ), {}),
]


@requires_torch_1_12
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
    """Check that ``flop_count`` reports positive fwd/bwd FLOPs for a function.

    Args:
        func: the callable whose FLOPs are counted.
        args: positional arguments forwarded to ``flop_count`` after ``func``.
        kwargs: keyword arguments forwarded to ``flop_count`` alongside
            ``verbose=True``.
    """
    rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
    assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'


if __name__ == '__main__':
    test_flop_count_module(tm.resnet18)
    test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})