diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index 175abac10..56af46e67 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -30,12 +30,7 @@ def run_fwd_bwd(model, x, enable_autocast=False):
 
 
 def run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)
@@ -52,11 +47,11 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-def test_shard_model_v2():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_shard_model_v2()
+    test_shard_model_v2(world_size=2)
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index bd04db77c..79bd8ee4c 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -4,19 +4,21 @@
 from copy import deepcopy
 from functools import partial
 
-import colossalai
-from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 import pytest
 import torch
 import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
+
 from tests.test_zero_data_parallel.common import Net, CONFIG, allclose
 
 
-def run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
@@ -32,9 +34,9 @@ def run_shard_tensor(rank, world_size, port):
 
 
 @pytest.mark.dist
-def test_shard_tensor():
-    world_size = 2
-    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_tensor(world_size):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
@@ -52,8 +54,8 @@ def _run_shard_param_v2(rank, world_size, port):
 
 
 @pytest.mark.dist
-def test_shard_param_v2():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
@@ -86,40 +88,40 @@ def _run_test_shard_param(rank, world_size, port):
 
 
 @pytest.mark.dist
-def test_shard_param():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_param(world_size):
     run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
-def run_init_shard_param(rank, world_size, port):
+def _run_init_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    param = torch.nn.Parameter(data=torch.rand(world_size, 3))
     sparam = ShardedParam(param, None, True)
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])
     del sparam
 
-    param_shape = (2, 3)
+    param_shape = (world_size, 3)
     sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])
 
-    param_shape = (2, 3)
+    param_shape = (world_size, 3)
     sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
     payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [2, 3])
+    assert (list(payload.shape) == [world_size, 3])
 
 
 @pytest.mark.dist
-def test_init_shard_param():
-    world_size = 2
-    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 4])
+def test_init_shard_param(world_size):
+    run_func = partial(_run_init_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_shard_tensor()
-    test_shard_param()
-    test_shard_param_v2()
-    test_init_shard_param()
+    test_shard_tensor(2)
+    test_shard_param(2)
+    test_shard_param_v2(2)
+    test_init_shard_param(4)
diff --git a/tests/test_zero_data_parallel/test_zero_param_mgr.py b/tests/test_zero_data_parallel/test_zero_param_mgr.py
index a38ed9286..8171a0946 100644
--- a/tests/test_zero_data_parallel/test_zero_param_mgr.py
+++ b/tests/test_zero_data_parallel/test_zero_param_mgr.py
@@ -1,41 +1,39 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-import os
 from functools import partial
-from pathlib import Path
-
-import colossalai
 import pytest
+
 import torch
 import torch.multiprocessing as mp
+
+import colossalai
 from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.utils import free_port
 
 from common import CONFIG
 
+
 def run_shard_shape_check(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
     model = torch.nn.Linear(2, 4 * world_size)
     gpc.init_parallel_groups()
-    Zero3ParameterManager(module=model, process_group=gpc.get_group(ParallelMode.DATA), offload_config=CONFIG.get('offload_param_config'))
+    Zero3ParameterManager(module=model,
+                          process_group=gpc.get_group(ParallelMode.DATA),
+                          offload_config=CONFIG.get('offload_param_config'))
 
-    assert(model.weight.numel() == 4 * 2)
-    assert(model.bias.numel() == 4)
+    assert (model.weight.numel() == 4 * 2)
+    assert (model.bias.numel() == 4)
 
 
 @pytest.mark.dist
-def test_run_shard_shape():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_run_shard_shape(world_size):
     run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
+
 if __name__ == '__main__':
-    test_run_shard_shape()
+    test_run_shard_shape(2)
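For reference, a minimal, self-contained sketch of the launch pattern these tests move to: `pytest.mark.parametrize` supplies `world_size`, `functools.partial` pins the keyword arguments, and `torch.multiprocessing.spawn` prepends the rank and starts one process per parametrized world size. The worker body and the fixed port `29500` are placeholders standing in for `colossalai.launch` and `free_port()`; the real tests use the functions shown in the diff.

```python
from functools import partial

import pytest
import torch.multiprocessing as mp


def _run_worker(rank, world_size, port):
    # Placeholder worker: the real tests call colossalai.launch(...) here
    # before exercising sharded tensors/params on each rank.
    assert 0 <= rank < world_size


@pytest.mark.parametrize("world_size", [1, 2])
def test_spawn_pattern(world_size):
    # partial() fixes world_size and port; mp.spawn calls run_func(rank)
    # once per process, so the worker signature stays (rank, world_size, port).
    run_func = partial(_run_worker, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)
```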