"""Parity test: ColossalAI Adam optimizers versus the PyTorch Adam/AdamW reference."""

from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

# (optimizer class, device the model is placed on) combinations under test.
_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device("cuda:0")),
    (CPUAdam, torch.device("cpu")),
    (CPUAdam, torch.device("cuda:0")),
    (HybridAdam, torch.device("cpu")),
    (HybridAdam, torch.device("cuda:0")),
]

# (parameter dtype, gradient dtype) combinations under test.
_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
    # (torch.half, torch.half),  # FIXME(ver217): cpu adam kernel does not support pure fp16
    # (torch.bfloat16, torch.bfloat16),  # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

# Number of optimizer steps executed (and checked) per test case.
N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
    """Build two optimizer parameter groups for ``bert_model``.

    Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"`` are
    placed in the second group with ``weight_decay=0.0``; every other
    parameter is placed in the first group with ``weight_decay=0.1``.

    Returns:
        A two-element list of ``{"params": ..., "weight_decay": ...}`` dicts,
        decayed group first.
    """
    no_decay = ("bias", "LayerNorm.weight")
    decayed, undecayed = [], []
    # Single pass over the named parameters; relative order within each
    # group follows ``named_parameters()``.
    for name, param in bert_model.named_parameters():
        exempt = any(marker in name for marker in no_decay)
        (undecayed if exempt else decayed).append(param)
    return [
        {"params": decayed, "weight_decay": 0.1},
        {"params": undecayed, "weight_decay": 0.0},
    ]


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
    """Give both models the same random gradients.

    Each parameter of ``torch_model`` receives a fresh ``torch.rand_like``
    gradient; the matching parameter of ``model`` receives a copy of that
    gradient converted to ``g_dtype``.
    """
    for param, ref_param in zip(model.parameters(), torch_model.parameters()):
        ref_grad = torch.rand_like(ref_param)
        ref_param.grad = ref_grad
        # Temporarily point ``param.data`` at the ``g_dtype`` tensor while the
        # gradient is attached, then restore the real data. This avoids the
        # inconsistent grad/param dtype error.
        saved_data = param.data
        param.data = ref_grad.clone().to(g_dtype)
        param.grad = param.data
        param.data = saved_data


@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
) -> None:
    """Check that ``optim_cls`` tracks torch ``Adam``/``AdamW`` over ``N_STEPS`` steps.

    A model from the ``transformers_bert_for_sequence_classification``
    registry is built on ``device`` as the reference; a deep copy cast to
    ``p_dtype`` is driven by the optimizer under test. After every step all
    parameters must be NaN-free and close to the reference parameters.
    """
    registry = model_zoo.get_sub_registry("transformers_bert_for_sequence_classification")
    # Only the first element of the first registry entry (the model factory) is used.
    model_fn, *_ = next(iter(registry.values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)

    # Hyper-parameters shared by the reference and the optimizer under test.
    hyper_params = {"lr": 1e-3, "betas": (0.9, 0.999), "eps": 1e-8}
    reference_cls = AdamW if adamw else Adam
    torch_optim = reference_cls(setup_param_groups(torch_model), **hyper_params)
    optim = optim_cls(setup_param_groups(model), adamw_mode=adamw, **hyper_params)

    # Tolerance depends on the lowest precision involved; bfloat16 takes
    # precedence over float16 when both would apply.
    dtypes_in_play = (p_dtype, g_dtype)
    if any(dt is torch.bfloat16 for dt in dtypes_in_play):
        tolerance = 4e-3
    elif any(dt is torch.float16 for dt in dtypes_in_play):
        tolerance = 2e-3
    else:
        tolerance = 1e-5

    for _step in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for param, ref_param in zip(model.parameters(), torch_model.parameters()):
            # if overflow, the weight won't be updated. so there will be no nan in p
            assert not torch.isnan(param).any()
            assert torch.allclose(param.float(), ref_param, rtol=tolerance, atol=tolerance)