#!/usr/bin/env python # -*- encoding: utf-8 -*- import copy import operator as op from functools import partial, reduce from typing import List import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn from colossalai.logging import disable_existing_loggers from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port from colossalai.zero.sharded_model import ShardedModel from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from colossalai.testing import parameterize def checkpoint_wrapper(module, enable=True): if enable: module.forward = partial(checkpoint, module.forward, False) return module class Net(nn.Module): def __init__(self, checkpoint=False) -> None: super().__init__() self.fc1 = nn.Linear(5, 5) self.fc2 = nn.Linear(5, 5) self.fc3 = nn.Linear(5, 1) if checkpoint: self.fc1 = checkpoint_wrapper(self.fc1) self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] def forward(self, x): for layer in self.layers: x = layer(x) return x def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0): model.train() optimizer.zero_grad() with torch.cuda.amp.autocast(enabled=enable_autocast): y = model(x) loss = y.sum() loss = loss.float() loss.backward() clip_grad(model, norm_type) optimizer.step() def clip_grad(model, norm_type): if isinstance(model, DDP): clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type) else: clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type) def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: if loose: return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) return torch.allclose(tensor_a, tensor_b) def check_grads(model, zero_model, loose=False): rank = dist.get_rank() for p, zero_p in zip(model.parameters(), zero_model.parameters()): zero_grad = zero_p.grad.clone().to(p.device) chunks = torch.flatten(p.grad).chunk(4) if rank >= len(chunks): continue grad = chunks[rank] if zero_p.zero_shard_padding > 0: zero_grad = zero_grad[:-zero_p.zero_shard_padding] assert grad.dtype == zero_grad.dtype assert allclose(grad, zero_grad, loose=loose) def check_params(model, zero_model, loose=False): rank = dist.get_rank() for p, zero_p in zip(model.parameters(), zero_model.parameters()): zero_shard_padding = zero_p.zero_shard_padding zero_p = zero_p.clone().to(p.device) chunks = torch.flatten(p).chunk(4) if rank >= len(chunks): continue p = chunks[rank] if zero_shard_padding > 0: zero_p = zero_p[:-zero_shard_padding] assert p.dtype == zero_p.dtype assert allclose(p, zero_p, loose=loose) @parameterize('checkpoint', [False, True]) @parameterize('fp16', [False, True]) @parameterize('offload', [False, True]) @parameterize('norm_type', [1.0, 2.0, float('inf')]) def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0): model = Net(checkpoint=checkpoint).cuda() zero_model = copy.deepcopy(model) ddp_model = DDP(model) offload_config = {} if offload: offload_config['device'] = 'cpu' zero_model = zero_model.cpu() zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config) optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3) zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3) for _ in range(5): x = torch.rand(2, 5).cuda() run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type) run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type) check_grads(ddp_model, zero_model) check_params(ddp_model, zero_model) for _ in range(5): x = torch.rand(2, 5).cuda() run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type) run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type) check_grads(ddp_model, zero_model, loose=True) check_params(ddp_model, zero_model, loose=True) def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_config() @pytest.mark.dist def test_zero_clip_grad(): world_size = 4 run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': test_zero_clip_grad()