You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_optimizer/test_fused_adam.py

64 lines
2.1 KiB

import torch
import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW
from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize
class FC(nn.Module):
    """Minimal single-layer model used as the optimisation target in this test.

    The one 64 -> 64 linear layer is wrapped in ``nn.Sequential`` and stored
    under the attribute ``fc``, so its state-dict keys are ``fc.0.weight`` and
    ``fc.0.bias``.
    """

    def __init__(self) -> None:
        super().__init__()
        layer = nn.Linear(64, 64)
        self.fc = nn.Sequential(layer)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x`` and return the result."""
        out = self.fc(x)
        return out
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    """Check that ``FusedAdam`` tracks the matching PyTorch optimizer.

    Two ``FC`` models start from identical weights (in ``p_dtype``, on CUDA).
    One is trained with ``FusedAdam`` and the other with ``torch.optim.Adam``
    (or ``AdamW`` when ``adamw`` is true, paired with ``adamw_mode=True``).
    Both see the same 1024 single-row steps. When ``g_dtype`` differs from
    ``p_dtype``, the fused model's gradients are cast to ``g_dtype`` before
    each step to exercise mixed parameter/gradient dtypes.

    Afterwards the parameters are compared with rtol=atol=2e-3.
    """
    model = FC().cuda().to(p_dtype)
    model_copy = FC().cuda().to(p_dtype)
    # load_state_dict copies values, so both models start from the same weights.
    model_copy.load_state_dict(model.state_dict())

    if adamw:
        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
    else:
        optim = FusedAdam(model.parameters(), lr=1e-3)
        torch_optim = Adam(model_copy.parameters(), lr=1e-3)

    data = torch.rand(1024, 64).cuda().to(p_dtype)
    data_copy = data.clone()
    label = torch.rand(1024, 64).cuda().to(p_dtype)

    fused_params = optim.param_groups[0]['params']
    torch_params = torch_optim.param_groups[0]['params']

    # Train the FusedAdam model one row at a time.
    for d, l in zip(data, label):
        y = model(d)
        loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
            for p in fused_params:
                # Assigning through ``.data`` allows the grad dtype to differ
                # from the parameter dtype; direct ``p.grad = ...`` assignment
                # with a mismatched dtype is rejected by PyTorch.
                p.grad.data = p.grad.data.to(g_dtype)
        optim.step()

    # Train the reference model on the same rows with the PyTorch optimizer.
    for d, l in zip(data_copy, label):
        y = model_copy(d)
        loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(fused_params) == len(torch_params)
    for idx, (fused_p, torch_p) in enumerate(zip(fused_params, torch_params)):
        if torch.isnan(torch_p).any():
            # The reference has diverged, so there is nothing meaningful to
            # compare against. Presumably this is fp16 underflow inside the
            # PyTorch optimizer; not verified here.
            continue
        # Previously a NaN on the fused side was also skipped, which let a
        # diverging FusedAdam pass silently. A finite reference must be matched.
        assert not torch.isnan(fused_p).any(), \
            f'FusedAdam produced NaN in param {idx} while the reference is finite'
        assert torch.allclose(fused_p, torch_p, 2e-3, 2e-3), \
            f'param {idx} mismatch: max abs diff {(fused_p - torch_p).abs().max().item()}'