[unittest] supported conditional testing based on env var (#1701)

polish code
Frank Lee 2 years ago committed by GitHub
parent 8283e95db3
commit 0e52f3d3d5

@@ -0,0 +1,17 @@
import pytest
import os


def run_on_environment_flag(name: str):
    """
    Conditionally run a test based on an environment variable. If the environment variable is set
    to 1, the test is executed; otherwise it is skipped. The environment variable defaults to 0.
    """
    assert isinstance(name, str)
    flag = os.environ.get(name.upper(), '0')
    reason = f'Environment variable {name} is {flag}'
    if flag == '1':
        return pytest.mark.skipif(False, reason=reason)
    else:
        return pytest.mark.skipif(True, reason=reason)
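For context, a minimal usage sketch of the new marker (the test name and assertion below are illustrative, not part of this commit). A decorated test only runs when the named environment variable is set to 1, e.g. invoking pytest as AUTO_PARALLEL=1 pytest; otherwise pytest reports it as skipped with the reason string built above.

import os

from colossalai.testing.pytest_wrapper import run_on_environment_flag


# Hypothetical test gated on the AUTO_PARALLEL flag; pytest collects it but
# skips it unless AUTO_PARALLEL=1 is present in the environment.
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_auto_parallel_feature():
    assert os.environ.get('AUTO_PARALLEL') == '1'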

@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
    """

    def _wrap_func(f):

        def _execute_by_gpu_num(*args, **kwargs):
            num_avail_gpu = torch.cuda.device_count()
            if num_avail_gpu >= min_gpus:
                f(*args, **kwargs)

        return _execute_by_gpu_num

    return _wrap_func
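For reference, a sketch of how the skip_if_not_enough_gpus decorator shown above is typically applied; the import path and test body are assumptions, not part of this commit. Unlike a pytest marker, the wrapper simply returns without reporting a skip when fewer GPUs are available.

import torch

from colossalai.testing import skip_if_not_enough_gpus


# Hypothetical test: the body executes only when at least 8 CUDA devices are
# visible, because the wrapper checks torch.cuda.device_count() before calling it.
@skip_if_not_enough_gpus(min_gpus=8)
def test_eight_gpu_case():
    assert torch.cuda.device_count() >= 8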

@@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
@@ -22,7 +23,7 @@ class ConvModel(nn.Module):
        return output
@pytest.mark.skip("temporarily skipped")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)

@@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
    assert output.equal(origin_output)
@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():

@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
@@ -33,7 +34,7 @@ class ConvModel(nn.Module):
        return x
@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_solver():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)

@@ -15,12 +15,13 @@ import transformers
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 8
SEQ_LENGHT = 8
@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)

@@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MLP(torch.nn.Module):
@@ -34,7 +35,7 @@ class MLP(torch.nn.Module):
        return x
@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)

@@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class BMMTensorMethodModule(nn.Module):
@@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module):
        return torch.bmm(x1, x2)
@pytest.mark.skip
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
@@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
    assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
@pytest.mark.skip
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
    model = module()

@@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
import pytest
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
    tracer = ColoTracer()

@@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)
