
[test] added missing decorators to model checkpointing tests

pull/728/head^2
FrankLeeeee authored 3 years ago, committed by Frank Lee
commit 62b4ce7326
  1. tests/test_utils/test_checkpoint/test_checkpoint_1d.py (6 changes)
  2. tests/test_utils/test_checkpoint/test_checkpoint_2d.py (6 changes)
  3. tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py (6 changes)
  4. tests/test_utils/test_checkpoint/test_checkpoint_3d.py (6 changes)

tests/test_utils/test_checkpoint/test_checkpoint_1d.py (6 changes)

@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_1d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_1d(rank, world_size, port):
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_1d():
     world_size = 8
     run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port())
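All four files receive the same three edits: the rerun_on_exception import, a one-line reformat of the parallel config, and the rerun_on_exception decorator on the pytest entry point. The decorator reruns a test when a spawned worker dies with an error matching the given pattern, which absorbs transient "Address already in use" failures when a freshly picked port is grabbed by a neighboring test before launch binds it. As a rough sketch only, not colossalai.testing's actual implementation, a retry decorator with this interface could look like:

import functools
import re

def rerun_on_exception_sketch(exception_type=Exception, pattern=None, max_try=5):
    # Illustrative only: the name, defaults, and retry count here are
    # assumptions, not colossalai.testing internals.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return func(*args, **kwargs)
                except exception_type as e:
                    # Retry only when the message matches the pattern,
                    # e.g. ".*Address already in use.*" for port clashes.
                    if pattern is not None and re.match(pattern, str(e)) is None:
                        raise
                    if attempt == max_try - 1:
                        raise
        return wrapper
    return decorator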

tests/test_utils/test_checkpoint/test_checkpoint_2d.py (6 changes)

@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, get_current_device, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_2d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_2d(rank, world_size, port):
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_2d():
     world_size = 8
     run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port())
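The diff truncates each test body after run_func. In tests structured like this, the bound function is typically handed to torch.multiprocessing.spawn, which supplies the rank argument to every worker; the closing spawn call below is a hedged reconstruction of that step, not the verbatim file contents:

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port

def test_checkpoint_2d():
    world_size = 8
    # Bind everything except `rank`, which mp.spawn passes to each worker
    # as the first positional argument. check_checkpoint_2d is the worker
    # defined earlier in this file.
    run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port())
    # Assumed closing line (truncated in the diff above):
    mp.spawn(run_func, nprocs=world_size)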

tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py (6 changes)

@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, get_current_device, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_2p5d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)
    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_2p5d(rank, world_size, port):
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_2p5d():
     world_size = 8
     run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port())
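The decorator and its arguments are identical in all four files; only the parallel layout differs (1d, 2d, 2.5d with depth=1, 3d). In every case the process count must equal pipeline size times tensor size, which is why each test sets world_size = 8. A small hypothetical helper, not part of ColossalAI, makes the arithmetic explicit:

def required_world_size(config):
    # Hypothetical helper: processes needed = pipeline size x tensor size.
    parallel = config["parallel"]
    return parallel["pipeline"]["size"] * parallel["tensor"]["size"]

# 2.5d layout from this file: 2 pipeline stages x 4 tensor ranks = 8.
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)
assert required_world_size(config) == 8

# 3d layout from the next file: 1 pipeline stage x 8 tensor ranks = 8.
config_3d = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)
assert required_world_size(config_3d) == 8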

tests/test_utils/test_checkpoint/test_checkpoint_3d.py (6 changes)

@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, get_current_device, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_3d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_3d(rank, world_size, port):
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_3d():
     world_size = 8
     run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port())
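The imports at the top of each file (save_checkpoint, load_checkpoint, gather_pipeline_parallel_state_dict) indicate what the elided bodies exercise: save a model, reload it into a fresh copy, and compare parameters with check_equal. The sketch below uses plain torch.save/torch.load to show the shape of that round trip; it stands in for colossalai's checkpointing helpers, whose exact signatures are not visible in this diff:

import copy

import torch

def check_equal(A, B):
    # Mirrors the check_equal helper referenced in the hunk headers,
    # assumed here to be an elementwise tensor comparison.
    assert torch.equal(A, B)

def roundtrip_sketch(model, path="ckpt.pt"):
    # Illustrative only: torch.save/torch.load standing in for
    # colossalai.utils.checkpointing.save_checkpoint/load_checkpoint.
    torch.save(model.state_dict(), path)
    reloaded = copy.deepcopy(model)
    reloaded.load_state_dict(torch.load(path))
    for p, q in zip(model.parameters(), reloaded.parameters()):
        check_equal(p, q)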
