diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py index ec677ad4e..0717c118b 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -15,6 +15,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port, is_using_pp from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception def build_pipeline(model): @@ -38,9 +39,7 @@ def check_equal(A, B): def check_checkpoint_1d(rank, world_size, port): - config = dict( - parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")), - ) + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @@ -68,6 +67,7 @@ def check_checkpoint_1d(rank, world_size, port): @pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") def test_checkpoint_1d(): world_size = 8 run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port()) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py index 76b6cecb5..42b39b91e 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -15,6 +15,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port, get_current_device, is_using_pp from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception def build_pipeline(model): @@ -38,9 +39,7 @@ def check_equal(A, B): def check_checkpoint_2d(rank, world_size, port): - config = dict( - parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")), - ) + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @@ -68,6 +67,7 @@ def check_checkpoint_2d(rank, world_size, port): @pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") def test_checkpoint_2d(): world_size = 8 run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port()) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py index 22bf8fbbf..7634a9706 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -15,6 +15,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port, get_current_device, is_using_pp from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception def build_pipeline(model): @@ -38,9 +39,7 @@ def check_equal(A, B): def check_checkpoint_2p5d(rank, world_size, port): - config = dict( - parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")), - ) + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @@ -68,6 +67,7 @@ def check_checkpoint_2p5d(rank, world_size, port): @pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") def test_checkpoint_2p5d(): world_size = 8 run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py index 1aa25dee8..740f3cfbd 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -15,6 +15,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port, get_current_device, is_using_pp from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint +from colossalai.testing import rerun_on_exception def build_pipeline(model): @@ -38,9 +39,7 @@ def check_equal(A, B): def check_checkpoint_3d(rank, world_size, port): - config = dict( - parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")), - ) + config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @@ -68,6 +67,7 @@ def check_checkpoint_3d(rank, world_size, port): @pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") def test_checkpoint_3d(): world_size = 8 run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port())