mirror of https://github.com/hpcaitech/ColossalAI
[test] refactored with the new rerun decorator (#763)
* [test] refactored with the new rerun decorator
* polish test case
pull/773/head
parent deaf99f4c9
commit 5a1a095b92
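Every file touched by this commit swaps the verbose exception-matching retry decorator for the purpose-built `rerun_if_address_is_in_use()`. A minimal sketch of the resulting test pattern, assuming the decorator and helpers named in the diff below; the `run_dist` body and test name here are illustrative, not copied from the repository:

from functools import partial

import pytest
import torch.multiprocessing as mp

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # each spawned worker joins a single-node process group on the given port
    colossalai.launch(config=dict(),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()    # replaces: @rerun_on_exception(exception_type=mp.ProcessRaisedException,
                                 #                               pattern=".*Address already in use.*")
def test_example():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
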
@@ -3,7 +3,7 @@ import colossalai
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import assert_close_loose, rerun_on_exception
+from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 
@@ -84,7 +84,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_naive_amp():
     world_size = 1
     run_func = partial(run_dist, world_size=world_size, port=free_port())

@@ -9,7 +9,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
 
@@ -64,7 +64,7 @@ def check_layer(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_comm():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())

@@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.context import reset_seeds
 from colossalai.global_variables import tensor_parallel_env as tp_env
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
 
@@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
 
 
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_context():
     """
     As no computation or communication is done, we can run this test on CPU.

@@ -17,7 +17,7 @@ from torchvision import transforms
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_dataloader, free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG = Config(
     dict(
@@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port):
 
 
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_data_sampler():
     world_size = 4
     test_func = partial(run_data_sampler, world_size=world_size, port=free_port())

@@ -17,7 +17,7 @@ from colossalai.builder import build_dataset, build_transform
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG = Config(
     dict(
@@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port):
 
 
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_data_sampler():
     world_size = 4
     test_func = partial(run_data_sampler, world_size=world_size, port=free_port())

@@ -15,7 +15,7 @@ from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import free_port, get_dataloader
 from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -23,9 +23,10 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
+CONFIG = dict(NUM_MICRO_BATCHES=2,
+              parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)
 
 
 def run_trainer(rank, world_size, port):
@@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_hybrid_parallel():
     world_size = 8
     run_func = partial(run_trainer, world_size=world_size, port=free_port())

@@ -7,7 +7,7 @@ from colossalai.amp import AMP_TYPE
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 
 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
               fp16=dict(mode=None),
@@ -56,7 +56,7 @@ def run_engine(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_engine():
     world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers
 from colossalai.initialize import launch
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_1d.check_layer_1d import *
 
 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
@@ -35,7 +35,7 @@ def check_layer(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_1d():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_2d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

@@ -7,7 +7,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_2p5d.check_layer_2p5d import *
 from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
 
@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_2p5d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_3d():
     world_size = 8
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

@@ -7,7 +7,7 @@ import pytest
 
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
 
 CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
@@ -132,7 +132,7 @@ def run_test(rank, world_size):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)

@@ -10,8 +10,7 @@ from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer,
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
-from colossalai.testing import assert_equal_in_group
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
 
 BATCH_SIZE = 4
 DIM = 16
@@ -63,7 +62,7 @@ def run_test(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_grad_handler():
     world_size = 4
     run_func = partial(run_test, world_size=world_size, port=free_port())

@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 BATCH_SIZE = 16
 NUM_EXPERTS = 4
@@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 @pytest.mark.parametrize("hidden_size", [32, 144])
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 @pytest.mark.parametrize("router", [Top1Router, Top2Router])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_kernel(rs, hidden_size, data_type, router):
     world_size = 4
     run_func = partial(run_routing,

@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.testing import assert_equal_in_group, rerun_on_exception
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
 
 D_MODEL = 4
 D_FF = 8
@@ -60,7 +60,7 @@ def run_test(rank, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_initialization():
     world_size = 4
     run_func = partial(run_test, port=free_port())

@@ -14,7 +14,7 @@ from colossalai.nn.layer import MoeModule
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG
 
@@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_init(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -4,7 +4,7 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -65,7 +65,7 @@ def run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -6,7 +6,7 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_optim(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer, free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 
 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -51,7 +51,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_trainer_no_pipeline():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())

@@ -17,13 +17,16 @@ from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200
 
-CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
+CONFIG = dict(
+    NUM_MICRO_BATCHES=2,
+    parallel=dict(pipeline=2),
+)
 
 
 def run_trainer_with_pipeline(rank, world_size, port):
@@ -85,7 +88,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_trainer_with_pipeline():
     world_size = 4
     run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())

@@ -1,6 +1,6 @@
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.zero.sharded_param import ShardedTensor
 import colossalai
 
@@ -35,7 +35,7 @@ def run_tensor_move(rank):
     assert (tgt_t.device.type == 'cpu')
 
 
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_tensor_move():
     mp.spawn(run_tensor_move, nprocs=1)
 

@@ -3,6 +3,7 @@ from functools import partial
 from pathlib import Path
 
 import colossalai
+from colossalai.testing.utils import rerun_if_address_is_in_use
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -10,7 +11,7 @@ import torch.nn as nn
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_dataloader
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -87,7 +88,7 @@ def run_no_pipeline(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_engine():
     world_size = 4
     func = partial(run_no_pipeline, world_size=world_size, port=free_port())

@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from functools import partial
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 
 
 def checkpoint_wrapper(module, enable=True):
@@ -102,7 +102,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())

@@ -6,7 +6,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import BucketTensorShardStrategy
@@ -62,7 +62,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_found_inf(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -8,7 +8,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.logging import get_dist_logger
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer.model_data_memtracer import \
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -10,7 +10,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
 
 
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_mem_collector(world_size=2):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -7,7 +7,7 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -59,7 +59,7 @@ def run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -5,12 +5,11 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.testing import rerun_on_exception
 from tests.test_zero.common import CONFIG, allclose
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
@@ -37,7 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_shard_tensor(world_size):
     run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -85,7 +84,7 @@ def _run_shard_param_v2(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -8,7 +8,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -105,7 +105,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sharded_optim_v2(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -10,7 +10,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
@@ -71,7 +71,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sharded_optim_with_sync_bn():
     """
     This test is to make sure that buffers are synchronized between ranks

@@ -8,7 +8,7 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -49,7 +49,7 @@ def run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_state_dict(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -10,7 +10,7 @@ from colossalai.gemini import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from torch.nn.parameter import Parameter
 from typing import List
 from functools import partial
@@ -120,8 +120,8 @@ def run_dist(rank, world_size, port):
     run_stm()
 
 
-@pytest.mark.gpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_stateful_tensor_manager(world_size=1):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage
                                            colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
                                            colo_model_tensor_clone)
 from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use
 
 import torch
 
@@ -84,6 +85,7 @@ def run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4, 5])
+@rerun_if_address_is_in_use()
 def test_zero_tensor_utils(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -9,7 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -96,7 +96,7 @@ def run_dist(rank, world_size, port, parallel_config):
 @pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_mp_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
@@ -104,7 +104,7 @@ def test_mp_engine(world_size):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_zero_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)

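For reference, the new helper can be thought of as a thin wrapper that bakes the address-in-use exception pattern into a general-purpose retry decorator like the one being removed above. A rough sketch of that idea, assuming a generic retry helper and the exception/pattern shown in the diff; this is not the library's actual implementation:

import re
from functools import wraps

import torch.multiprocessing as mp


def rerun_on_exception(exception_type=Exception, pattern=None, max_try=5):
    # generic retry decorator: rerun the wrapped test while the raised
    # exception matches the given type and message pattern
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return func(*args, **kwargs)
                except exception_type as e:
                    if pattern and not re.match(pattern, str(e)):
                        raise
                    if attempt == max_try - 1:
                        raise
        return wrapper
    return decorator


def rerun_if_address_is_in_use():
    # convenience wrapper: retry only on the "address already in use"
    # error raised by torch.multiprocessing workers
    return rerun_on_exception(exception_type=mp.ProcessRaisedException,
                              pattern=".*Address already in use.*")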