polish unit test

pull/403/head
ver217 2022-03-14 15:06:02 +08:00
parent 88804aee49
commit 54fd37f0e0
6 changed files with 64 additions and 53 deletions
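Summary of the change: each ZeRO unit test below now takes the shard strategy as a pytest parameter instead of hard-coding TensorShardStrategy. The strategy class (TensorShardStrategy or BucketTensorShardStrategy) is added to the parametrize decorators, forwarded into the spawned worker through functools.partial and torch.multiprocessing.spawn, and only instantiated inside the worker after colossalai.launch. Below is a minimal sketch of that shared pattern, with placeholder strategy classes and a fixed port standing in for colossalai's real classes and free_port():

from functools import partial

import pytest
import torch.multiprocessing as mp


class TensorShardStrategy:          # placeholder for colossalai.zero.shard_utils.TensorShardStrategy
    pass


class BucketTensorShardStrategy:    # placeholder for colossalai.zero.shard_utils.BucketTensorShardStrategy
    pass


def run_dist(rank, world_size, port, shard_strategy):
    # The real tests call colossalai.launch(...) here first, then build the
    # strategy instance inside each rank: shard_strategy = shard_strategy().
    strategy = shard_strategy()
    print(f"rank {rank}/{world_size} (port {port}) uses {type(strategy).__name__}")


@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_example(world_size, shard_strategy):
    # Keyword arguments are bound with partial; mp.spawn supplies the rank.
    run_func = partial(run_dist, world_size=world_size, port=29500, shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)

Passing the class rather than an instance through mp.spawn keeps anything tied to per-process distributed state out of the parent process; that rationale is an inference, not stated in the commit.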

View File

@@ -4,21 +4,20 @@
 from functools import partial
 
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 
 
-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     for get_components_func in non_distributed_component_funcs:
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
-                             shard_strategy=TensorShardStrategy(),
+                             shard_strategy=shard_strategy(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)
@@ -50,11 +49,16 @@ def run_dist(rank, world_size, port, init_device):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       init_device=init_device,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'))
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)

View File

@@ -3,30 +3,28 @@
 import copy
 from functools import partial
 
-import pytest
-import torch
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-
 import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
+import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-from common import CONFIG, check_grads_padding, run_fwd_bwd
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = TensorShardStrategy()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -66,14 +64,16 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast)
+                       enable_autocast=enable_autocast,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
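Note that the stacked parametrize decorators multiply: with world_size in [1, 2], enable_autocast in [True], use_zero_init_ctx in [True] and the two strategies, the test above now generates 2 x 1 x 1 x 2 = 4 cases, so each strategy is exercised at both world sizes. A tiny self-contained example of that stacking behaviour (the names here are illustrative, not taken from the test suite):

import pytest


@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("shard_strategy", ["TensorShardStrategy", "BucketTensorShardStrategy"])
def test_cross_product(world_size, shard_strategy):
    # pytest runs this once per (shard_strategy, world_size) pair: 4 cases.
    assert world_size in (1, 2)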

View File

@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose
 
 
-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]
 
-    shard_strategy = TensorShardStrategy(process_group=None)
+    shard_strategy = shard_strategy(process_group=None)
 
     # test shard strategy
     shard_strategy.shard([t])
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    test_shard_tensor(2, TensorShardStrategy)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)

View File

@@ -10,7 +10,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -38,12 +38,12 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()
 
 
-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
@@ -69,10 +69,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)

View File

@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -47,12 +47,12 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
     optimizer.step()
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
@@ -79,10 +79,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)

View File

@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model_builder()
-        shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
         zero_state_dict = zero_model.state_dict()
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_zero_state_dict():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_state_dict()
+    test_zero_state_dict(2, TensorShardStrategy)