#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""Helpers for checking gradient clipping of a ZeRO model against a reference model."""

import copy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2


def checkpoint_wrapper(module, enable=True):
    """Optionally route ``module.forward`` through colossalai's ``checkpoint``.

    Args:
        module: The ``nn.Module`` whose ``forward`` is patched in place.
        enable: When False, ``module`` is returned untouched.

    Returns:
        The same ``module`` object.
    """
    if enable:
        # The positional ``False`` is passed to ``checkpoint`` right after the
        # function; presumably the activation-offload flag — TODO confirm.
        module.forward = partial(checkpoint, module.forward, False)
    return module


class Net(nn.Module):
    """Tiny network applying fc1, fc2, fc1, fc2, fc3 in order (5 inputs -> 1 output).

    Args:
        checkpoint: If True, ``fc1``'s forward is wrapped by ``checkpoint_wrapper``.
            (Inside ``__init__`` this name shadows the module-level ``checkpoint``
            function; ``checkpoint_wrapper`` still sees the global one.)
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        # A plain list is enough: the layers are already registered as
        # submodules through the attributes above; this list only fixes the call
        # order, reusing fc1 and fc2 twice each.
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    """Run one training step: forward, sum-loss backward, gradient clipping, update.

    Args:
        model: A DDP-wrapped model or any other model (see ``clip_grad``).
        optimizer: Optimizer over ``model``'s parameters.
        x: Input batch fed to ``model``.
        enable_autocast: Run the forward pass under ``torch.cuda.amp.autocast``.
        norm_type: Norm type forwarded to the clipping function.
    """
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    # Upcast so backward starts from an fp32 scalar even when autocast is on.
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()


def clip_grad(model, norm_type):
    """Clip ``model``'s gradients to a total norm of 1.0.

    DDP models use PyTorch's ``clip_grad_norm_``; every other model uses
    colossalai's ``clip_grad_norm_fp32``.
    """
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    """``torch.allclose`` with default tolerances, or atol=rtol=1e-3 when ``loose``."""
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def _reference_shard(tensor, rank, world_size):
    """Return chunk ``rank`` of ``tensor`` flattened and split ``world_size`` ways, or None.

    ``torch.chunk`` may return fewer than ``world_size`` chunks for small tensors;
    ranks past the last chunk have no counterpart and get None.
    """
    chunks = torch.flatten(tensor).chunk(world_size)
    if rank >= len(chunks):
        return None
    return chunks[rank]


def check_grads(model, zero_model, loose=False):
    """Assert this rank's ZeRO gradient shard matches its slice of the reference gradient.

    Args:
        model: Reference (unsharded) model.
        zero_model: Sharded model. Its parameters are assumed to carry a
            ``zero_shard_padding`` attribute giving trailing padding to strip
            (NOTE(review): confirm against the sharding implementation in use).
        loose: Use the looser tolerance of ``allclose``.
    """
    rank = dist.get_rank()
    # Chunk by the real world size (was hard-coded to 4) so the comparison is
    # aligned with the shards for any number of processes.
    world_size = dist.get_world_size()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        grad = _reference_shard(p.grad, rank, world_size)
        if grad is None:
            continue
        zero_grad = zero_p.grad.clone().to(p.device)
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    """Assert this rank's ZeRO parameter shard matches its slice of the reference parameter.

    Args:
        model: Reference (unsharded) model.
        zero_model: Sharded model whose parameters are assumed to carry a
            ``zero_shard_padding`` attribute (NOTE(review): confirm).
        loose: Use the looser tolerance of ``allclose``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        shard = _reference_shard(p, rank, world_size)
        if shard is None:
            continue
        # Read the padding from the original parameter: ``clone()`` returns a new
        # tensor that does not carry custom attributes.
        zero_shard_padding = zero_p.zero_shard_padding
        zero_shard = zero_p.clone().to(p.device)
        if zero_shard_padding > 0:
            zero_shard = zero_shard[:-zero_shard_padding]
        assert shard.dtype == zero_shard.dtype
        assert allclose(shard, zero_shard, loose=loose)


def run_dist(rank, world_size, port):
    """Per-process entry point for ``mp.spawn``: launch colossalai with the NCCL backend.

    NOTE(review): this only initialises the distributed environment. No model is
    built and none of the check helpers in this module are called, so the spawned
    test asserts nothing beyond a successful launch.
    TODO: restore the sharded-vs-DDP comparison.
    """
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    """Spawn ``world_size`` processes, each running ``run_dist`` on a free port."""
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_clip_grad()