|
|
|
@ -15,6 +15,7 @@ from colossalai.initialize import launch
|
|
|
|
|
from colossalai.logging import disable_existing_loggers |
|
|
|
|
from colossalai.utils import free_port, get_current_device, is_using_pp |
|
|
|
|
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint |
|
|
|
|
from colossalai.testing import rerun_on_exception |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_pipeline(model): |
|
|
|
@ -38,9 +39,7 @@ def check_equal(A, B):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_checkpoint_2p5d(rank, world_size, port): |
|
|
|
|
config = dict( |
|
|
|
|
parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")), |
|
|
|
|
) |
|
|
|
|
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) |
|
|
|
|
|
|
|
|
|
disable_existing_loggers() |
|
|
|
|
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") |
|
|
|
@ -68,6 +67,7 @@ def check_checkpoint_2p5d(rank, world_size, port):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist |
|
|
|
|
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") |
|
|
|
|
def test_checkpoint_2p5d(): |
|
|
|
|
world_size = 8 |
|
|
|
|
run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) |
|
|
|
|