using pytest parametrize

pull/394/head
jiaruifang 2022-03-08 12:03:35 +08:00 committed by Frank Lee
parent dec24561cf
commit 799d105bb4
3 changed files with 43 additions and 48 deletions

@@ -30,12 +30,7 @@ def run_fwd_bwd(model, x, enable_autocast=False):
 def run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)
@@ -52,11 +47,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_model_v2():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

 if __name__ == '__main__':
-    test_shard_model_v2()
+    test_shard_model_v2(world_size=2)
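For reference, the pattern this commit applies to all three test files can be boiled down to the following sketch. The test function and the `_run_worker` helper are hypothetical stand-ins (the real tests use `free_port()` and call `colossalai.launch` inside the worker); the point is that `@pytest.mark.parametrize` turns each `world_size` value into its own test case, and `functools.partial` carries that value into the processes spawned by `torch.multiprocessing`.

import pytest
import torch.multiprocessing as mp
from functools import partial


def _run_worker(rank, world_size, port):
    # the real tests call colossalai.launch(...) here; this stub only checks the rank
    assert 0 <= rank < world_size


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_example(world_size):
    # the real tests use free_port(); a fixed port stands in for it in this sketch
    run_func = partial(_run_worker, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)

When invoked directly, as the `__main__` blocks now do, a parametrized test is just a regular function, so a concrete value has to be passed explicitly, e.g. `test_shard_model_v2(world_size=2)`.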

@@ -4,19 +4,21 @@
 from copy import deepcopy
 from functools import partial

-import colossalai
-from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 import pytest
 import torch
 import torch.multiprocessing as mp
+import colossalai
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 from tests.test_zero_data_parallel.common import Net, CONFIG, allclose


-def run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
@@ -32,9 +34,9 @@ def run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_tensor():
-    world_size = 2
-    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_tensor(world_size):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -52,8 +54,8 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_param_v2():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -86,40 +88,40 @@ def _run_test_shard_param(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_param():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_shard_param(world_size):
     run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


-def run_init_shard_param(rank, world_size, port):
+def _run_init_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    param = torch.nn.Parameter(data=torch.rand(world_size, 3))
     sparam = ShardedParam(param, None, True)
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])
     del sparam

-    param_shape = (2, 3)
+    param_shape = (world_size, 3)
     sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])

-    param_shape = (2, 3)
+    param_shape = (world_size, 3)
     sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
     payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [2, 3])
+    assert (list(payload.shape) == [world_size, 3])


 @pytest.mark.dist
-def test_init_shard_param():
-    world_size = 2
-    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 4])
+def test_init_shard_param(world_size):
+    run_func = partial(_run_init_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_tensor()
-    test_shard_param()
-    test_shard_param_v2()
-    test_init_shard_param()
+    test_shard_tensor(2)
+    test_shard_param(2)
+    test_shard_param_v2(2)
+    test_init_shard_param(4)
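Changing the fixed `(2, 3)` shapes to `(world_size, 3)` is what keeps the shard assertions valid for every parametrized value: a parameter with `world_size * 3` elements, split evenly over `world_size` ranks, leaves exactly 3 elements per rank. A rough sketch of that arithmetic, assuming a plain flatten-and-chunk split rather than the real `TensorShardStrategy`/`ShardedParam` internals:

import torch


def naive_shard(param, rank, world_size):
    # flatten and split evenly across ranks; illustrative only, not the real sharding logic
    flat = param.detach().reshape(-1)
    return flat.chunk(world_size)[rank]


world_size = 4
param = torch.nn.Parameter(torch.rand(world_size, 3))
for rank in range(world_size):
    # world_size * 3 elements / world_size ranks = 3 elements per rank
    assert list(naive_shard(param, rank, world_size).shape) == [3]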

@@ -1,41 +1,39 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import os
 from functools import partial
-from pathlib import Path

-import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+import colossalai
 from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.utils import free_port
 from common import CONFIG


 def run_shard_shape_check(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     model = torch.nn.Linear(2, 4 * world_size)
     gpc.init_parallel_groups()
-    Zero3ParameterManager(module=model, process_group=gpc.get_group(ParallelMode.DATA), offload_config=CONFIG.get('offload_param_config'))
+    Zero3ParameterManager(module=model,
+                          process_group=gpc.get_group(ParallelMode.DATA),
+                          offload_config=CONFIG.get('offload_param_config'))

-    assert(model.weight.numel() == 4 * 2)
-    assert(model.bias.numel() == 4)
+    assert (model.weight.numel() == 4 * 2)
+    assert (model.bias.numel() == 4)


 @pytest.mark.dist
-def test_run_shard_shape():
-    world_size = 2
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_run_shard_shape(world_size):
     run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_run_shard_shape()
+    test_run_shard_shape(2)
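The same reasoning makes the shape check in the last file independent of the chosen `world_size`: `torch.nn.Linear(2, 4 * world_size)` has `8 * world_size` weight elements and `4 * world_size` bias elements, so an even split across the data-parallel group leaves `4 * 2` weight elements and `4` bias elements on each rank, matching the asserts. A sketch under the same flatten-and-chunk assumption (not the actual `Zero3ParameterManager` partitioning, which may pad or group parameters differently):

import torch

world_size = 4
model = torch.nn.Linear(2, 4 * world_size)

# illustrative even split of each parameter across ranks
weight_shard = model.weight.detach().reshape(-1).chunk(world_size)[0]
bias_shard = model.bias.detach().reshape(-1).chunk(world_size)[0]

assert weight_shard.numel() == 4 * 2   # (4 * world_size) * 2 / world_size
assert bias_shard.numel() == 4         # (4 * world_size) / world_size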