mirror of https://github.com/hpcaitech/ColossalAI
parent
8283e95db3
commit
0e52f3d3d5
|
@ -0,0 +1,17 @@
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def run_on_environment_flag(name: str):
|
||||||
|
"""
|
||||||
|
Conditionally run a test based on the environment variable. If this environment variable is set
|
||||||
|
to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0.
|
||||||
|
"""
|
||||||
|
assert isinstance(name, str)
|
||||||
|
flag = os.environ.get(name.upper(), '0')
|
||||||
|
|
||||||
|
reason = f'Environment varialbe {name} is {flag}'
|
||||||
|
if flag == '1':
|
||||||
|
return pytest.mark.skipif(False, reason=reason)
|
||||||
|
else:
|
||||||
|
return pytest.mark.skipif(True, reason=reason)
|
|
@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _wrap_func(f):
|
def _wrap_func(f):
|
||||||
|
|
||||||
def _execute_by_gpu_num(*args, **kwargs):
|
def _execute_by_gpu_num(*args, **kwargs):
|
||||||
num_avail_gpu = torch.cuda.device_count()
|
num_avail_gpu = torch.cuda.device_count()
|
||||||
if num_avail_gpu >= min_gpus:
|
if num_avail_gpu >= min_gpus:
|
||||||
f(*args, **kwargs)
|
f(*args, **kwargs)
|
||||||
|
|
||||||
return _execute_by_gpu_num
|
return _execute_by_gpu_num
|
||||||
|
|
||||||
return _wrap_func
|
return _wrap_func
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
class ConvModel(nn.Module):
|
class ConvModel(nn.Module):
|
||||||
|
@ -22,7 +23,7 @@ class ConvModel(nn.Module):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("temporarily skipped")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
def test_where_handler():
|
def test_where_handler():
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
mesh_shape = (2, 2)
|
mesh_shape = (2, 2)
|
||||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
|
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
class ConvModel(nn.Module):
|
class ConvModel(nn.Module):
|
||||||
|
@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
|
||||||
assert output.equal(origin_output)
|
assert output.equal(origin_output)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_apply():
|
def test_apply():
|
||||||
|
|
|
@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
class ConvModel(nn.Module):
|
class ConvModel(nn.Module):
|
||||||
|
@ -33,7 +34,7 @@ class ConvModel(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
def test_solver():
|
def test_solver():
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
mesh_shape = (2, 2)
|
mesh_shape = (2, 2)
|
||||||
|
|
|
@ -15,12 +15,13 @@ import transformers
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
BATCH_SIZE = 8
|
BATCH_SIZE = 8
|
||||||
SEQ_LENGHT = 8
|
SEQ_LENGHT = 8
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
def test_cost_graph():
|
def test_cost_graph():
|
||||||
physical_mesh_id = torch.arange(0, 8)
|
physical_mesh_id = torch.arange(0, 8)
|
||||||
mesh_shape = (2, 4)
|
mesh_shape = (2, 4)
|
||||||
|
|
|
@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
class MLP(torch.nn.Module):
|
class MLP(torch.nn.Module):
|
||||||
|
@ -34,7 +35,7 @@ class MLP(torch.nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
def test_cost_graph():
|
def test_cost_graph():
|
||||||
physical_mesh_id = torch.arange(0, 8)
|
physical_mesh_id = torch.arange(0, 8)
|
||||||
mesh_shape = (2, 4)
|
mesh_shape = (2, 4)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule
|
||||||
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
|
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
class BMMTensorMethodModule(nn.Module):
|
class BMMTensorMethodModule(nn.Module):
|
||||||
|
@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module):
|
||||||
return torch.bmm(x1, x2)
|
return torch.bmm(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
||||||
def test_2d_device_mesh(module):
|
def test_2d_device_mesh(module):
|
||||||
|
|
||||||
|
@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
|
||||||
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
|
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
|
||||||
def test_1d_device_mesh(module):
|
def test_1d_device_mesh(module):
|
||||||
model = module()
|
model = module()
|
||||||
|
|
|
@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
import pytest
|
import pytest
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
def test_norm_pool_handler():
|
def test_norm_pool_handler():
|
||||||
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
|
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
|
|
|
@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50
|
||||||
from colossalai.auto_parallel.solver.constants import *
|
from colossalai.auto_parallel.solver.constants import *
|
||||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("for higher testing speed")
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
def test_cost_graph():
|
def test_cost_graph():
|
||||||
physical_mesh_id = torch.arange(0, 8)
|
physical_mesh_id = torch.arange(0, 8)
|
||||||
mesh_shape = (2, 4)
|
mesh_shape = (2, 4)
|
||||||
|
|
Loading…
Reference in New Issue