From 91cd34e6e0063a53a072a37e9864c5f14931fb52 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 14 Oct 2022 10:56:17 +0800 Subject: [PATCH] [unittest] added doc for the pytest wrapper (#1704) --- colossalai/testing/pytest_wrapper.py | 23 +++++++++++++++++++ .../test_shape_consistency_pass.py | 3 ++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py index eb6858892..a472eb372 100644 --- a/colossalai/testing/pytest_wrapper.py +++ b/colossalai/testing/pytest_wrapper.py @@ -1,3 +1,9 @@ +""" +This file will not be automatically imported by `colossalai.testing` +as this file has a dependency on `pytest`. Therefore, you need to +explicitly import this file `from colossalai.testing.pytest_wrapper import `.from +""" + import pytest import os @@ -6,6 +12,23 @@ def run_on_environment_flag(name: str): """ Conditionally run a test based on the environment variable. If this environment variable is set to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0. + + Args: + name (str): the name of the environment variable flag. + + Usage: + # in your pytest file + @run_on_environment_flag(name='SOME_FLAG') + def test_for_something(): + do_something() + + # in your terminal + # this will execute your test + SOME_FLAG=1 pytest test_for_something.py + + # this will skip your test + pytest test_for_something.py + """ assert isinstance(name, str) flag = os.environ.get(name.upper(), '0') diff --git a/tests/test_auto_parallel/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_shape_consistency_pass.py index 1b1ffbdf9..27a16a1cf 100644 --- a/tests/test_auto_parallel/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_shape_consistency_pass.py @@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass from colossalai.auto_parallel.solver.solver import Solver_V2 from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.testing.pytest_wrapper import run_on_environment_flag class ConvModel(nn.Module): @@ -73,7 +74,7 @@ def check_apply(rank, world_size, port): assert output.equal(origin_output) -@pytest.mark.skip("for higher testing speed") +@run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() def test_apply():