import copy

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(123, 253)
        self.linear_drop = nn.Linear(253, 253)
        self.linear2 = nn.Linear(253, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)


def split_ddp_grad(grad, world_size):
    with torch.no_grad():
        # flatten the gradient and pad it so it can be evenly split across ranks
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        split_grad = grad.split(grad.numel() // world_size)
    return split_grad


@parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
    """
    In this test, we want to test whether zero stage 1 and 2
    deliver the same numerical results despite their different
    communication patterns.

    We use these prefixes to differentiate the zero stages:
    oss: partition optimizer states
    pg: partition gradients and optimizer states
    """
    local_rank = torch.distributed.get_rank()
    seed_all(2001)

    # create model
    zero1_model = MlpModel().cuda()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
    zero1_optimizer = LowLevelZeroOptimizer(
        zero1_optimizer,
        overlap_communication=True,
        initial_scale=128,
        verbose=True,
        fp8_communication=fp8_communication,
    )
    zero2_optimizer = LowLevelZeroOptimizer(
        zero2_optimizer,
        overlap_communication=True,
        partition_grad=True,
        initial_scale=128,
        fp8_communication=fp8_communication,
    )

    # create data
    seed_all(2001 + local_rank)
    input_data = torch.randn(32, 123).cuda()

    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    assert torch.equal(zero1_output, zero2_output)

    # zero-dp backward
    zero1_optimizer.backward(zero1_output.mean().float())
    zero2_optimizer.backward(zero2_output.mean().float())

    # check grad
    for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()):
        g1 = zero1_optimizer.get_param_grad(p1)
        g2 = zero2_optimizer.get_param_grad(p2)
        if g1 is None or g2 is None:
            assert g1 is None and g2 is None
            continue
        if fp8_communication:
            loose_close(g1, g2, dtype=torch.float16)
        else:
            assert torch.allclose(g1, g2)

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        if not fp8_communication:
            assert torch.allclose(z1p, z2p)


@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
@parameterize("extra_dp_size", [1, 2])
def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int):
    """
    In this test, two pairs of models and optimizers are created.
    1. zero: use the sharded optimizer and fp16 parameters
    2. torch: use torch DDP and fp32 parameters

    We feed these two sets of models with the same input and check whether the
    differences in model output and updated parameters are within tolerance.
    """
    if extra_dp_size > 1 and dtype != torch.bfloat16:
        return
    if extra_dp_size > 1:
        pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
        extra_dp_group = pg_mesh.get_group_along_axis(0)
        dp_group = pg_mesh.get_group_along_axis(1)
    else:
        extra_dp_group = None
        dp_group = None

    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models
    torch_model = MlpModel().cuda().to(dtype)
    zero_model = copy.deepcopy(torch_model).to(dtype)

    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(
        zero_optimizer,
        overlap_communication=True,
        initial_scale=1,
        reduce_bucket_size=1024 * 1024,
        master_weights=master_weights,
        dp_process_group=dp_group,
        extra_dp_group=extra_dp_group,
    )

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    seed_all(1453 + local_rank)
    for _ in range(2):
        # create input data
        input_data = torch.rand(32, 123).cuda().to(dtype)

        # zero-dp forward
        zero_output = zero_model(input_data)

        # torch-ddp forward
        torch_output = torch_model(input_data)
        loose_close(zero_output, torch_output, dtype=dtype)

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean())

        # torch-ddp backward
        torch_output.mean().backward()

        # check grad
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            zero_grad = zero_optimizer.get_param_grad(z1p)
            if p.grad is None:
                assert zero_grad is None
                continue
            loose_close(p.grad, zero_grad, dtype=dtype)

        # zero-dp step
        zero_optimizer.step()

        # torch ddp step
        torch_optimizer.step()

        zero_optimizer._force_wait_all_gather()

        # check updated param
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            loose_close(p, z1p, dtype=dtype)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    exam_zero_1_torch_ddp()
    exam_zero_1_2()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_zero_1_2()
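

# Illustration only, not invoked by the tests above: a minimal sketch of how the
# `split_ddp_grad` helper would be used. It flattens and pads a full DDP gradient,
# then picks the chunk owned by the current rank, e.g. to compare against a
# ZeRO-sharded gradient. The function name `_example_ddp_grad_shard` and the call
# pattern are assumptions for illustration, not part of the original test flow.
def _example_ddp_grad_shard(ddp_param: nn.Parameter) -> torch.Tensor:
    # assumes an initialized process group and a populated `.grad`, e.g. inside `run_dist`
    world_size = dist.get_world_size()
    # one flat chunk per rank, padded to a multiple of world_size
    shards = split_ddp_grad(ddp_param.grad, world_size)
    # the chunk this rank would own under a flat ZeRO-style partitioning
    return shards[dist.get_rank()]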