[test] refactored with the new rerun decorator (#763)

* [test] refactored with the new rerun decorator

* polish test case
pull/773/head
Frank Lee 2022-04-15 00:33:04 +08:00 committed by GitHub
parent deaf99f4c9
commit 5a1a095b92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 80 additions and 75 deletions

View File

@ -3,7 +3,7 @@ import colossalai
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import assert_close_loose, rerun_on_exception
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
@ -84,7 +84,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_naive_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -9,7 +9,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@ -64,7 +64,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_comm():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())

View File

@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.context import reset_seeds
from colossalai.global_variables import tensor_parallel_env as tp_env
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_context():
"""
As no computation or communication is done, we can run this test on CPU.

View File

@ -17,7 +17,7 @@ from torchvision import transforms
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = Config(
dict(
@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())

View File

@ -17,7 +17,7 @@ from colossalai.builder import build_dataset, build_transform
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = Config(
dict(
@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())

View File

@ -15,7 +15,7 @@ from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks
from colossalai.utils import free_port, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -23,7 +23,8 @@ from torchvision.datasets import CIFAR10
BATCH_SIZE = 4
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
CONFIG = dict(NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_hybrid_parallel():
world_size = 8
run_func = partial(run_trainer, world_size=world_size, port=free_port())

View File

@ -7,7 +7,7 @@ from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
fp16=dict(mode=None),
@ -56,7 +56,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_engine():
world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port())

View File

@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_1d.check_layer_1d import *
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
@ -35,7 +35,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_1d():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())

View File

@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_2d():
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@ -7,7 +7,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_2p5d():
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_3d():
world_size = 8
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@ -7,7 +7,7 @@ import pytest
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
@ -132,7 +132,7 @@ def run_test(rank, world_size):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_sequence():
world_size = 4
run_func = partial(run_test, world_size=world_size)

View File

@ -10,8 +10,7 @@ from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer,
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group
from colossalai.testing import rerun_on_exception
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
BATCH_SIZE = 4
DIM = 16
@ -63,7 +62,7 @@ def run_test(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_grad_handler():
world_size = 4
run_func = partial(run_test, world_size=world_size, port=free_port())

View File

@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
BATCH_SIZE = 16
NUM_EXPERTS = 4
@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("router", [Top1Router, Top2Router])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, router):
world_size = 4
run_func = partial(run_routing,

View File

@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_on_exception
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
D_MODEL = 4
D_FF = 8
@ -60,7 +60,7 @@ def run_test(rank, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_initialization():
world_size = 4
run_func = partial(run_test, port=free_port())

View File

@ -14,7 +14,7 @@ from colossalai.nn.layer import MoeModule
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import get_current_device
from tests.test_zero.common import CONFIG
@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_zero_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -4,7 +4,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@ -65,7 +65,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
BATCH_SIZE = 4
IMG_SIZE = 32
@ -51,7 +51,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_trainer_no_pipeline():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -17,13 +17,16 @@ from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
BATCH_SIZE = 4
IMG_SIZE = 32
NUM_EPOCHS = 200
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
CONFIG = dict(
NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2),
)
def run_trainer_with_pipeline(rank, world_size, port):
@ -85,7 +88,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_trainer_with_pipeline():
world_size = 4
run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())

View File

@ -1,6 +1,6 @@
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.sharded_param import ShardedTensor
import colossalai
@ -35,7 +35,7 @@ def run_tensor_move(rank):
assert (tgt_t.device.type == 'cpu')
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_tensor_move():
mp.spawn(run_tensor_move, nprocs=1)

View File

@ -3,6 +3,7 @@ from functools import partial
from pathlib import Path
import colossalai
from colossalai.testing.utils import rerun_if_address_is_in_use
import pytest
import torch
import torch.multiprocessing as mp
@ -10,7 +11,7 @@ import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -87,7 +88,7 @@ def run_no_pipeline(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_engine():
world_size = 4
func = partial(run_no_pipeline, world_size=world_size, port=free_port())

View File

@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
def checkpoint_wrapper(module, enable=True):
@ -102,7 +102,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -6,7 +6,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
@ -62,7 +62,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_found_inf(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -8,7 +8,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_init_context(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -10,7 +10,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -7,7 +7,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@ -59,7 +59,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_shard_model_v2(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -5,12 +5,11 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.testing import rerun_on_exception
from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
@ -37,7 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_shard_tensor(world_size):
run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@ -85,7 +84,7 @@ def _run_shard_param_v2(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_shard_param_v2(world_size):
run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -8,7 +8,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@ -105,7 +105,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_sharded_optim_v2(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -10,7 +10,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
@ -71,7 +71,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_sharded_optim_with_sync_bn():
"""
This test is to make sure that buffers are synchronized between ranks

View File

@ -8,7 +8,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@ -49,7 +49,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_state_dict(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -10,7 +10,7 @@ from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from torch.nn.parameter import Parameter
from typing import List
from functools import partial
@ -120,8 +120,8 @@ def run_dist(rank, world_size, port):
run_stm()
@pytest.mark.gpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_stateful_tensor_manager(world_size=1):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -6,6 +6,7 @@ from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
import torch
@ -84,6 +85,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5])
@rerun_if_address_is_in_use()
def test_zero_tensor_utils(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
@ -96,7 +96,7 @@ def run_dist(rank, world_size, port, parallel_config):
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_mp_engine(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
mp.spawn(run_func, nprocs=world_size)
@ -104,7 +104,7 @@ def test_mp_engine(world_size):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_engine(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
mp.spawn(run_func, nprocs=world_size)