From 0e52f3d3d5c659d71b4457a59cc8098f699e8f35 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Thu, 13 Oct 2022 19:38:45 +0800
Subject: [PATCH] [unittest] supported conditional testing based on env var
 (#1701)

polish code
---
 colossalai/testing/pytest_wrapper.py          | 17 +++++++++++++++++
 colossalai/testing/utils.py                   |  3 ++-
 .../test_deprecated_where_handler.py          |  3 ++-
 .../test_deprecated_shape_consistency_pass.py |  3 ++-
 .../test_deprecated/test_deprecated_solver.py |  3 ++-
 .../test_deprecated_solver_with_gpt.py        |  3 ++-
 .../test_deprecated_solver_with_mlp.py        |  3 ++-
 .../test_node_handler/test_bmm_handler.py     |  5 +++--
 .../test_norm_pooling_handler.py              |  3 ++-
 .../test_solver_with_resnet_v2.py             |  3 ++-
 10 files changed, 36 insertions(+), 10 deletions(-)
 create mode 100644 colossalai/testing/pytest_wrapper.py

diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py
new file mode 100644
index 000000000..eb6858892
--- /dev/null
+++ b/colossalai/testing/pytest_wrapper.py
@@ -0,0 +1,17 @@
+import pytest
+import os
+
+
+def run_on_environment_flag(name: str):
+    """
+    Conditionally run a test based on an environment variable. If the environment variable is set
+    to 1, the test will be executed; otherwise, it is skipped. The environment variable defaults to 0.
+    """
+    assert isinstance(name, str)
+    flag = os.environ.get(name.upper(), '0')
+
+    reason = f'Environment variable {name} is {flag}'
+    if flag == '1':
+        return pytest.mark.skipif(False, reason=reason)
+    else:
+        return pytest.mark.skipif(True, reason=reason)
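
Note: the helper above returns a pytest marker; the if/else over two
skipif calls can also be collapsed into a single marker with a computed
condition. A minimal equivalent sketch (illustrative only, not part of
this patch):

    import os

    import pytest


    def run_on_environment_flag(name: str):
        """Skip the decorated test unless the environment variable
        `name` (upper-cased) is set to '1'; it defaults to '0'."""
        assert isinstance(name, str)
        flag = os.environ.get(name.upper(), '0')
        reason = f'Environment variable {name} is {flag}'
        # A computed condition replaces the explicit if/else branches.
        return pytest.mark.skipif(flag != '1', reason=reason)
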
diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py
index 4f0c2beee..64c1d6e7b 100644
--- a/colossalai/testing/utils.py
+++ b/colossalai/testing/utils.py
@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
     """
 
     def _wrap_func(f):
+
         def _execute_by_gpu_num(*args, **kwargs):
             num_avail_gpu = torch.cuda.device_count()
             if num_avail_gpu >= min_gpus:
                 f(*args, **kwargs)
+
         return _execute_by_gpu_num
 
     return _wrap_func
-
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
index 1fd8fff2d..294a59fc8 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
@@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class ConvModel(nn.Module):
@@ -22,7 +23,7 @@ class ConvModel(nn.Module):
         return output
 
 
-@pytest.mark.skip("temporarily skipped")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_where_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
index b15497d2c..3286b325c 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
@@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
 from colossalai.auto_parallel.tensor_shard.deprecated import Solver
 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class ConvModel(nn.Module):
@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
     assert output.equal(origin_output)
 
 
-@pytest.mark.skip("for higher testing speed")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_apply():
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
index df640050a..65bbd6bc3 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap
 from copy import deepcopy
 from colossalai.auto_parallel.tensor_shard.deprecated import Solver
 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class ConvModel(nn.Module):
@@ -33,7 +34,7 @@ class ConvModel(nn.Module):
         return x
 
 
-@pytest.mark.skip("for higher testing speed")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_solver():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
index ac0ce1b87..e90d6b153 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
@@ -15,12 +15,13 @@ import transformers
 from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
 from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 BATCH_SIZE = 8
 SEQ_LENGHT = 8
 
 
-@pytest.mark.skip("for higher testing speed")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_cost_graph():
     physical_mesh_id = torch.arange(0, 8)
     mesh_shape = (2, 4)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
index 7ba63951d..415156ed6 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
@@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
 from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class MLP(torch.nn.Module):
@@ -34,7 +35,7 @@ class MLP(torch.nn.Module):
         return x
 
 
-@pytest.mark.skip("for higher testing speed")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_cost_graph():
     physical_mesh_id = torch.arange(0, 8)
     mesh_shape = (2, 4)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
index ad45ee3f1..9ec536743 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
@@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule
 from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class BMMTensorMethodModule(nn.Module):
@@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module):
         return torch.bmm(x1, x2)
 
 
-@pytest.mark.skip
+@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_2d_device_mesh(module):
 
@@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
     assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
 
 
-@pytest.mark.skip
+@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_1d_device_mesh(module):
     model = module()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
index 63ca627d4..423940558 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
@@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 import pytest
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
-@pytest.mark.skip("for higher testing speed")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
index a8e90ba0b..a75337f10 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
@@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.solver.constants import *
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
-@pytest.mark.skip("for higher testing speed")
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_cost_graph():
     physical_mesh_id = torch.arange(0, 8)
     mesh_shape = (2, 4)
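
Note: with this change, the auto-parallel tests become opt-in instead of
hard-skipped. A sketch of how a test module adopts the decorator and how
the suite is then invoked (test name and path are illustrative):

    from colossalai.testing.pytest_wrapper import run_on_environment_flag


    @run_on_environment_flag(name='AUTO_PARALLEL')
    def test_cost_graph():
        # Skipped by default; executed when the flag is set, e.g.
        #   AUTO_PARALLEL=1 pytest tests/test_auto_parallel/
        ...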