import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


class TestModel(nn.Module):
    """Toy model: two stacked linear layers mapping 128 -> 256 -> 512 features."""

    # The name starts with "Test", so stop pytest from trying to collect it
    # as a test class (it would only emit a collection warning).
    __test__ = False

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def half_close(a, b, loose=False):
    """Assert that two tensors match once both are cast to fp16.

    Both tensors are detached and converted to half precision first, so an
    fp16 tensor can be compared against an fp32 reference.

    Args:
        a: first tensor.
        b: second tensor.
        loose: if True, compare with ``rtol=5e-2`` and ``atol=5e-4``;
            otherwise use ``assert_close``'s default tolerances for fp16.

    Raises:
        AssertionError: if the tensors differ beyond the tolerances.
    """
    # assert_close requires rtol/atol to be both None or both set.
    rtol, atol = (5e-2, 5e-4) if loose else (None, None)

    a = a.detach().half()
    b = b.detach().half()

    assert_close(a, b, rtol=rtol, atol=atol)


def exam_zero_1_2():
    """Check that ZeRO stage 1 and stage 2 give identical numerical results.

    The ``zero1_*`` objects use stage 1 (partition optimizer states only).
    The ``zero2_*`` objects use stage 2 (``partition_grad=True``: partition
    gradients and optimizer states). The communication patterns differ, yet
    outputs, gradients and updated parameters must be exactly equal.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(2001)

    # create two identical models
    zero1_model = TestModel().cuda()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizers: stage 1 vs stage 2
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
    zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
                                            overlap_communication=True,
                                            initial_scale=128,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=128)

    # create data; the seed is offset by rank so each rank gets different inputs
    seed_all(2001 + local_rank)
    input_data = torch.randn(32, 128).cuda()

    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    assert torch.equal(zero1_output, zero2_output)

    # zero-dp backward
    zero1_optimizer.backward(zero1_output.mean().float())
    zero2_optimizer.backward(zero2_output.mean().float())

    # Stage 2 may leave some param.grad as None (presumably params whose grad
    # shard is not kept on this rank -- TODO confirm); compare only where set.
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        if z2p.grad is not None:
            assert torch.equal(z1p.grad, z2p.grad)

    zero1_optimizer.sync_grad()
    zero2_optimizer.sync_grad()

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_torch_ddp():
    """Compare a ZeRO stage-1 fp16 model against a torch DDP fp32 reference.

    Two model/optimizer pairs are built from the same initial weights:

    1. zero: sharded optimizer (``LowLevelZeroOptimizer``) with fp16 params.
    2. torch: ``DistributedDataParallel`` with fp32 params.

    Both receive the same per-rank input. Outputs, gradients and updated
    parameters must agree within a loose fp16 tolerance.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models from the same initial weights
    zero_model = TestModel()
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
    # Wrap the reference in DDP so its gradients are averaged across ranks,
    # like the ZeRO side. Each rank feeds different data, so a bare local
    # model would be compared against un-averaged local gradients.
    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)

    # create optimizers
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
    # Only stage 1 is tested here; exam_zero_1_2 checks that stages 1 and 2
    # produce exactly the same results.
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           reduce_bucket_size=262144)

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    # create input; the seed is offset by rank so each rank gets different data
    seed_all(1453 + local_rank)
    input_data = torch.rand(32, 128).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.half())

    # torch-ddp forward
    torch_output = torch_model(input_data)
    half_close(zero_output, torch_output, loose=True)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float())

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad (both models register parameters in the same order)
    for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
        half_close(p.grad, z1p.grad, loose=True)

    # zero-dp step
    zero_optimizer.sync_grad()
    zero_optimizer.step()

    # torch ddp step
    torch_optimizer.step()

    # check updated param
    for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
        half_close(p.data, z1p.data, loose=True)


def run_dist(rank, world_size, port):
    """Per-process entry point: initialise colossalai, then run both checks."""
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_torch_ddp()
    exam_zero_1_2()


@pytest.mark.dist
def test_zero_1_2():
    """Spawn a 2-process group on a free local port and run ``run_dist``."""
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_1_2()