diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index 012b0ff43..4affa3789 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -330,6 +330,7 @@ def autoparallelize(model: nn.Module,
                                      device_mesh,
                                      solver_preference=solver_preference,
                                      dataloader_option=dataloader_option,
+                                     shard_option=shard_option,
                                      save_solver_solution=save_solver_solution,
                                      load_solver_solution=load_solver_solution,
                                      solution_path=solver_solution_path,
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
index 365981f10..e4982a5d7 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
@@ -33,11 +33,15 @@ def check_compatibility_with_ddp(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).cuda()
-    input = torch.rand(4, 4).cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda()
+    input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
@@ -70,7 +74,10 @@ def check_compatibility_with_ddp(rank, world_size, port):
 
     gm = DDP(gm, process_group=dp_process_group)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank{rank} is correct')
 
     loss = output.sum()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
index b4080c545..760401c3f 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -37,12 +37,15 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).half().cuda()
-
-    input = torch.rand(4, 4).half().cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16).reshape(4, 4).half().cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32).reshape(4, 4).half().cuda()
+    input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
@@ -79,7 +82,10 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
     optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank{rank} is correct')
     loss = output.sum()
     optimizer.zero_grad()
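
For context, here is a minimal single-process sketch (not part of the diff) of why both tests now divide the reference gradient by 2 and slice `output_compare` with `narrow`: with a sum loss, the gradient from the full 8x4 comparison batch equals the sum of the two per-rank 4x4 half-batch gradients, while DDP averages gradients across the data-parallel group of size 2. The bare `nn.Linear(4, 4, bias=False)` standing in for `MLP(4)` is an assumption for illustration only.

```python
# Standalone CPU sketch; assumes a bare nn.Linear in place of the tests' MLP(4).
import torch
import torch.nn as nn

linear = nn.Linear(4, 4, bias=False)

# Full comparison batch, built the same way as input_compare in the tests.
full_batch = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)
linear(full_batch).sum().backward()
grad_full = linear.weight.grad.clone()
linear.weight.grad = None

# Per-rank half batches: rows 0-3 feed ranks 0/1, rows 4-7 feed ranks 2/3.
half_grads = []
for half in (full_batch.narrow(0, 0, 4), full_batch.narrow(0, 4, 4)):
    linear(half).sum().backward()
    half_grads.append(linear.weight.grad.clone())
    linear.weight.grad = None

# DDP averages gradients over the two data-parallel replicas, so the
# reference to compare against is the full-batch gradient divided by 2.
ddp_style_grad = (half_grads[0] + half_grads[1]) / 2
torch.testing.assert_close(grad_full / 2, ddp_style_grad)
```

The same slicing idea drives the output checks: each rank only produces the rows of `output_compare` that correspond to its half of the batch, hence `output_compare.narrow(0, 0, 4)` for ranks 0/1 and `output_compare.narrow(0, 4, 4)` for ranks 2/3.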