mirror of https://github.com/hpcaitech/ColossalAI
[unittest] added doc for the pytest wrapper (#1704)
parent
451cd72dea
commit
91cd34e6e0
|
@ -1,3 +1,9 @@
|
||||||
|
"""
|
||||||
|
This file will not be automatically imported by `colossalai.testing`
|
||||||
|
as this file has a dependency on `pytest`. Therefore, you need to
|
||||||
|
explicitly import this file `from colossalai.testing.pytest_wrapper import <func>`.from
|
||||||
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -6,6 +12,23 @@ def run_on_environment_flag(name: str):
|
||||||
"""
|
"""
|
||||||
Conditionally run a test based on the environment variable. If this environment variable is set
|
Conditionally run a test based on the environment variable. If this environment variable is set
|
||||||
to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0.
|
to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the environment variable flag.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# in your pytest file
|
||||||
|
@run_on_environment_flag(name='SOME_FLAG')
|
||||||
|
def test_for_something():
|
||||||
|
do_something()
|
||||||
|
|
||||||
|
# in your terminal
|
||||||
|
# this will execute your test
|
||||||
|
SOME_FLAG=1 pytest test_for_something.py
|
||||||
|
|
||||||
|
# this will skip your test
|
||||||
|
pytest test_for_something.py
|
||||||
|
|
||||||
"""
|
"""
|
||||||
assert isinstance(name, str)
|
assert isinstance(name, str)
|
||||||
flag = os.environ.get(name.upper(), '0')
|
flag = os.environ.get(name.upper(), '0')
|
||||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
|
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
|
||||||
from colossalai.auto_parallel.solver.solver import Solver_V2
|
from colossalai.auto_parallel.solver.solver import Solver_V2
|
||||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
class ConvModel(nn.Module):
|
class ConvModel(nn.Module):
|
||||||
|
@ -73,7 +74,7 @@ def check_apply(rank, world_size, port):
|
||||||
assert output.equal(origin_output)
|
assert output.equal(origin_output)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_apply():
|
def test_apply():
|
||||||
|
|
Loading…
Reference in New Issue