diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/sharded_model/utils.py
new file mode 100644
index 000000000..7b7c634d3
--- /dev/null
+++ b/colossalai/zero/sharded_model/utils.py
@@ -0,0 +1,19 @@
+import torch
+from colossalai.zero.sharded_model import ShardedModelV2
+
+import copy
+
+
+def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
+    """
+    Copy the parameters of sharded_model into other_model.
+    Note that other_model must have the same architecture as sharded_model.
+    """
+    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
+        assert hasattr(zero_param, 'col_attr')
+        shard_flag = zero_param.col_attr.data.is_sharded
+        if shard_flag:
+            sharded_model.shard_strategy.gather([zero_param.col_attr.data])    # make the full payload available
+        param.data = copy.deepcopy(zero_param.col_attr.data.payload)
+        if shard_flag:
+            sharded_model.shard_strategy.shard([zero_param.col_attr.data])    # restore the sharded state
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index b4677f06f..163b098c0 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -3,8 +3,10 @@ from functools import partial
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
+from colossalai.zero.sharded_model import ShardedModelV2
 
 LOGGER = get_dist_logger()
 
@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),
              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
 
 
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
+    model.train()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()
+
+
 def checkpoint_wrapper(module, enable=True):
     if enable:
         module.forward = partial(checkpoint, module.forward)
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index 919b74ed3..a1885e8f0 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -3,81 +3,70 @@
 import copy
 from functools import partial
 
+import pytest
+
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
+from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+
 from tests.components_to_test.registry import non_distributed_component_funcs
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import CONFIG, check_grads_padding
+from common import CONFIG, check_grads_padding, run_fwd_bwd
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 
 
-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(data)
-        loss = criterion(y, label)
-        loss = loss.float()
-    if isinstance(model, ShardedModelV2):
-        model.backward(loss)
-    else:
-        loss.backward()
-
-
-# with no criterion
-def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        loss = model(data, label)
-    if isinstance(model, ShardedModelV2):
-        model.backward(loss)
-    else:
-        loss.backward()
-
-
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = TensorShardStrategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        model = model(checkpoint=True).half().cuda()
-        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
-        if dist.get_world_size() > 1:
-            model = DDP(model)
+        model_builder, train_dataloader, _, _, criterion = get_components_func()
+
+        if use_zero_init_ctx:
+            with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
+                zero_model = model_builder(checkpoint=True)
+            zero_model = ShardedModelV2(zero_model, shard_strategy)
+
+            model = model_builder(checkpoint=True).half()
+            col_model_deepcopy(zero_model, model)
+            model = model.cuda()
+        else:
+            model = model_builder(checkpoint=True).half().cuda()
+            zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
+
+        model = DDP(model)
 
         for i, (data, label) in enumerate(train_dataloader):
-            if i > 2:
+            if i > 3:
                 break
-            if criterion is None:
-                data, label = data.cuda(), label.cuda()
-                run_fwd_bwd_no_criterion(model, data, label, False)
-                run_fwd_bwd_no_criterion(zero_model, data, label, False)
-            else:
-                data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
-                run_fwd_bwd(model, data, label, criterion, False)
-                run_fwd_bwd(zero_model, data, label, criterion, False)
+            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
+            run_fwd_bwd(model, data, label, criterion, enable_autocast)
+            run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
 
             check_grads_padding(model, zero_model, loose=True)
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
-def test_shard_model_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("enable_autocast", [True])
+@pytest.mark.parametrize("use_zero_init_ctx", [True])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       use_zero_init_ctx=use_zero_init_ctx,
+                       enable_autocast=enable_autocast)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
diff --git a/tests/test_zero_data_parallel/test_sharded_model_with_ctx.py b/tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
deleted file mode 100644
index 1dbcbd804..000000000
--- a/tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from colossalai.utils import free_port
-from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
-from colossalai.zero.sharded_model import ShardedModelV2
-from tests.components_to_test.registry import non_distributed_component_funcs
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import CONFIG, check_grads, check_grads_padding
-
-
-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(data)
-        loss = criterion(y, label)
-        loss = loss.float()
-    if isinstance(model, ShardedModelV2):
-        model.backward(loss)
-    else:
-        loss.backward()
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_models = ['repeated_computed_layers', 'resnet18']
-    for model_name in test_models:
-        get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
-        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
-            zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-            zero_model = zero_model()
-        model = copy.deepcopy(zero_model)
-        zero_model = ShardedModelV2(zero_model, shard_strategy)
-        model_state_dict = zero_model.state_dict()
-        for n, p in model.named_parameters():
-            p.data = model_state_dict[n]
-        model = model.half().cuda()
-        if dist.get_world_size() > 1:
-            model = DDP(model)
-
-        for i, (data, label) in enumerate(train_dataloader):
-            if i > 2:
-                break
-            data, label = data.half().cuda(), label.cuda()
-            run_fwd_bwd(model, data, label, criterion, False)
-            run_fwd_bwd(zero_model, data, label, criterion, False)
-            if dist.get_world_size() > 1:
-                check_grads_padding(model, zero_model, loose=True)
-            else:
-                check_grads(model, zero_model, loose=True)
-
-
-@pytest.mark.dist
-def test_shard_model_v2():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_shard_model_v2()
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
index 83cce1e30..ad0113578 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
@@ -78,7 +78,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 2])
 def test_sharded_optim_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
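
Usage note: col_model_deepcopy gathers each sharded parameter, deep-copies its full payload into the matching parameter of a plain torch module, and then re-shards, so the copy can serve as an unsharded reference model as in the test above. A minimal sketch of the intended call pattern (illustrative only; BuildModel is a hypothetical model constructor, and the ZeroInitContext/ShardedModelV2 arguments mirror those used in the test):

    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2
    from colossalai.zero.sharded_model.utils import col_model_deepcopy

    shard_strategy = TensorShardStrategy()

    # build the ZeRO model with fp16, CUDA and sharded parameters
    with ZeroInitContext(convert_fp16=True, convert_cuda=True,
                         shard_strategy=shard_strategy, shard_param=True):
        zero_model = BuildModel()          # hypothetical model constructor
    zero_model = ShardedModelV2(zero_model, shard_strategy)

    # build an unsharded fp16 twin and copy the ZeRO parameters into it
    ref_model = BuildModel().half()
    col_model_deepcopy(zero_model, ref_model)
    ref_model = ref_model.cuda()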