You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/tests/test_analyzer/test_subclasses/test_flop_tensor.py

56 lines
1.8 KiB

import pytest
import torch
import torch.nn.functional as F
import torchvision.models as tm
from packaging import version
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
    from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except Exception:
    # Best-effort import: a failure here is tolerated so the module can still
    # be collected (presumably on torch builds the analyzer does not support;
    # the tests in this file are skipped for torch < 1.12.0 -- TODO confirm).
    # ``Exception`` is caught instead of using a bare ``except`` so that
    # KeyboardInterrupt and SystemExit are no longer swallowed.
    pass
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@pytest.mark.parametrize("m", tm_models + tmm_models)
def test_flop_count_module(m):
    """Check that ``flop_count`` reports positive forward and backward FLOPs for a model.

    Args:
        m: Zero-argument callable that builds the module under test; pytest
            supplies one entry of ``tm_models + tmm_models`` per run.
    """
    x = torch.rand(2, 3, 224, 224)
    with MetaTensorMode():  # save time for testing
        module = m()
    rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
    assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}"
    assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}"
# Functional ops exercised by ``test_flop_count_function``; every entry is a
# ``(func, args, kwargs)`` triple.
_INPUT_SHAPE = (2, 3, 224, 224)

_relu_case = (
    F.relu,
    (torch.rand(*_INPUT_SHAPE, requires_grad=True),),
    {"inplace": True},
)

_max_pool2d_case = (
    F.max_pool2d,
    (torch.rand(*_INPUT_SHAPE, requires_grad=True),),
    {"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2},
)

_where_case = (
    torch.where,
    (
        torch.rand(*_INPUT_SHAPE) > 0.5,  # boolean condition mask
        torch.rand(*_INPUT_SHAPE, requires_grad=True),
        torch.rand(*_INPUT_SHAPE, requires_grad=True),
    ),
    {},
)

odd_cases = [_relu_case, _max_pool2d_case, _where_case]
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@pytest.mark.parametrize("func, args, kwargs", odd_cases)
def test_flop_count_function(func, args, kwargs):
    """Check that ``flop_count`` reports positive forward and backward FLOPs for a function.

    Args:
        func: Callable under test.
        args: Positional arguments forwarded to ``flop_count`` after ``func``.
        kwargs: Keyword arguments forwarded to ``flop_count`` alongside
            ``verbose=True``.
    """
    rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
    assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}"
    assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}"
if __name__ == "__main__":
    # Direct invocation without pytest: run one module case and one function case.
    test_flop_count_module(tm.resnet18)

    relu_input = torch.rand(2, 3, 224, 224, requires_grad=True)
    test_flop_count_function(F.relu, (relu_input,), {"inplace": True})