mirror of https://github.com/hpcaitech/ColossalAI
[test] added missing decorators to model checkpointing tests
parent 1cb7bdad3b
commit 62b4ce7326
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_1d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)
 
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_1d(rank, world_size, port):
 
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_1d():
     world_size = 8
     run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port())
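The point of the added decorator: when free_port() hands out a port that gets taken again before launch() binds it, the worker dies with an "Address already in use" error and the whole test fails for a reason unrelated to checkpointing. rerun_on_exception retries the test in that case instead of reporting a failure. Below is a minimal sketch of what such a retry decorator can look like; it is only an illustration of the idea, not the actual colossalai.testing implementation, and the max_try parameter is an assumption.

# Illustrative sketch only -- not the colossalai.testing implementation.
# Assumed behaviour: rerun the wrapped test up to max_try times when the
# raised exception matches both the given type and the message pattern.
import re
from functools import wraps


def rerun_on_exception(exception_type=Exception, pattern=None, max_try=3):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_try + 1):
                try:
                    return func(*args, **kwargs)
                except exception_type as exc:
                    # A non-matching message means a genuine failure,
                    # so re-raise instead of retrying.
                    if pattern is not None and not re.search(pattern, str(exc)):
                        raise
                    # Out of retries: surface the original error.
                    if attempt == max_try:
                        raise
        return wrapper
    return decorator

Used as in the diff, @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") only swallows the flaky port collision; assertion errors and other crashes still propagate on the first run.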
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, get_current_device, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_2d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)
 
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_2d(rank, world_size, port):
 
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_2d():
     world_size = 8
     run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port())
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, get_current_device, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_2p5d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)
 
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_2p5d(rank, world_size, port):
 
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_2p5d():
     world_size = 8
     run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port())
@@ -15,6 +15,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port, get_current_device, is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
+from colossalai.testing import rerun_on_exception
 
 
 def build_pipeline(model):
@@ -38,9 +39,7 @@ def check_equal(A, B):
 
 
 def check_checkpoint_3d(rank, world_size, port):
-    config = dict(
-        parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),
-    )
+    config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)
 
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -68,6 +67,7 @@ def check_checkpoint_3d(rank, world_size, port):
 
 
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_checkpoint_3d():
     world_size = 8
     run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port())
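For context, the run_func = partial(...) built in each test is the per-rank entry point, and mp.ProcessRaisedException in the decorator is the exception torch.multiprocessing.spawn raises in the parent when a worker process errors out. The lines below sit past the hunk cut-off and are not part of this commit; they are only an assumed sketch of how such a test is typically driven.

# Assumed continuation of test_checkpoint_1d(); not shown in the diff above.
import torch.multiprocessing as mp

def test_checkpoint_1d():
    world_size = 8
    run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port())
    # spawn() runs run_func(rank) in world_size worker processes; a worker that
    # dies with "Address already in use" is reported to the parent as
    # mp.ProcessRaisedException, which the new decorator catches and retries.
    mp.spawn(run_func, nprocs=world_size)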