From 7fa6be49d2ac1eae2eda60f150597f0d3998ddf7 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 15 Feb 2023 09:43:29 +0800
Subject: [PATCH] [autoparallel] test compatibility for gemini and auto parallel (#2700)

---
 .../passes/runtime_preparation_pass.py |  10 +-
 .../test_compatibility_with_ddp.py     |  98 ++++++++++++++++
 .../test_compatibility_with_gemini.py  | 108 ++++++++++++++++++
 3 files changed, 212 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py

diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 897602ce1..ecf3f1f18 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
                 # TODO: build a ColoParamter class to manager the distributed parameters
                 # we could use .data here, because all the operations just happen before the real training
                 # loop, so we don't need to track these operations in the autograd graph.
-                param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-                    param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+                param = torch.nn.Parameter(
+                    shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                             target_sharding_spec).detach().clone())
             setattr(target_module, name, param)
 
         comm_actions = node.best_strategy.communication_actions
@@ -432,8 +433,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
             # TODO: build a ColoParamter class to manager the distributed parameters
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
-            target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-                target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+            target = torch.nn.Parameter(
+                shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                         target_sharding_spec).detach().clone())
 
         assert hasattr(target_module, atoms[-1])
         setattr(target_module, atoms[-1], target)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
new file mode 100644
index 000000000..365981f10
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
@@ -0,0 +1,98 @@
+import copy
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, in_features):
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
+        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)
+
+    def forward(self, x):
+        x = self.linear_1(x)
+        x = self.linear_2(x)
+
+        return x
+
+
+def check_compatibility_with_ddp(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = MLP(4).cuda()
+    input = torch.rand(4, 4).cuda()
+    output_compare = model(input)
+    loss_compare = output_compare.sum()
+    loss_compare.backward()
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    meta_args = {'x': torch.rand(4, 4).to('meta')}
+    gm, solution = initialize_model(model,
+                                    meta_args=meta_args,
+                                    device_mesh=device_mesh,
+                                    return_solution=True,
+                                    solver_preference='tp',
+                                    shard_option='shard_last_axis')
+
+    msg = '| TP strategy combination chosen by auto-parallel solver |'
+    msg_length = len(msg)
+    if rank == 0:
+        print('=' * msg_length)
+        print(msg)
+        print('=' * msg_length)
+        for strategy in solution:
+            print(strategy)
+        print('=' * msg_length)
+
+    dp_process_group = None
+    for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]:
+        if rank in ranks:
+            dp_process_group = process_group_handle
+    assert dp_process_group is not None
+    gm = DDP(gm, process_group=dp_process_group)
+    output = gm(input)
+
+    assert_close(output, output_compare)
+    print(f'output on rank{rank} is correct')
+    loss = output.sum()
+
+    loss.backward()
+
+    if rank in (0, 2):
+        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8))
+
+    if rank in (1, 3):
+        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8))
+
+    print(f'gradient on rank{rank} is correct')
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_compatibility_with_ddp():
+    world_size = 4
+    run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_compatibility_with_ddp()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
new file mode 100644
index 000000000..b4080c545
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -0,0 +1,108 @@
+import copy
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, in_features):
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
+        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)
+
+    def forward(self, x):
+        x = self.linear_1(x)
+        x = self.linear_2(x)
+
+        return x
+
+
+def check_auto_parallel_with_gemini(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = MLP(4).half().cuda()
+
+    input = torch.rand(4, 4).half().cuda()
+    output_compare = model(input)
+    loss_compare = output_compare.sum()
+    loss_compare.backward()
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    meta_args = {'x': torch.rand(4, 4).half().to('meta')}
+    gm, solution = initialize_model(model,
+                                    meta_args=meta_args,
+                                    device_mesh=device_mesh,
+                                    return_solution=True,
+                                    solver_preference='tp',
+                                    shard_option='shard_last_axis')
+
+    if rank == 0:
+        msg = '| TP strategy combination chosen by auto-parallel solver |'
+        msg_length = len(msg)
+        print('=' * msg_length)
+        print(msg)
+        print('=' * msg_length)
+        for strategy in solution:
+            print(strategy)
+        print('=' * msg_length)
+
+    dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
+    gemini_config = dict(strict_ddp_mode=False,
+                         device=get_current_device(),
+                         placement_policy='cpu',
+                         pin_memory=True,
+                         search_range_mb=128)
+
+    post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
+    gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
+    optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
+    optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
+    output = gm(input)
+    assert_close(output, output_compare)
+    print(f'output on rank{rank} is correct')
+    loss = output.sum()
+    optimizer.zero_grad()
+    optimizer.backward(loss)
+    optimizer.step()
+
+    if rank in (0, 2):
+        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten())
+
+    if rank in (1, 3):
+        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten())
+
+    print(f'gradient on rank{rank} is correct')
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_auto_parallel_with_gemini():
+    world_size = 4
+    run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_auto_parallel_with_gemini()
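
Note on the core change (not part of the applied diff): _module_params_sharding now builds a fresh
torch.nn.Parameter from the redistributed tensor and re-registers it on the module, instead of
mutating the existing Parameter's .data in place. The minimal PyTorch-only sketch below illustrates
that difference under stated assumptions; the Linear module, shard shapes, and variable names are
illustrative and do not come from this repository, and the motivation comments are hedged reading of
the patch, not a statement of the authors' intent.

    import torch

    linear = torch.nn.Linear(8, 8, bias=False)

    # Illustrative stand-in for a redistributed shard (first half of the rows).
    shard = linear.weight.detach().narrow(0, 0, 4).clone()

    # Old behaviour: mutate the registered Parameter's storage in place.
    # The Parameter object keeps its identity but silently changes shape.
    linear.weight.data = shard

    # New behaviour in this patch: wrap the redistributed tensor in a new
    # Parameter and re-register it via setattr, so wrappers applied afterwards
    # (such as the DDP and Gemini/ZeRO wrappers exercised by the added tests)
    # presumably see an ordinary leaf Parameter already holding the shard shape.
    new_param = torch.nn.Parameter(shard.detach().clone())
    setattr(linear, 'weight', new_param)

    print(linear.weight.shape)    # torch.Size([4, 8])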