mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
62 lines
2.1 KiB
62 lines
2.1 KiB
import torch
|
|
import torch.nn as nn
|
|
from torch.optim.adam import Adam
|
|
from torch.optim import AdamW
|
|
|
|
from colossalai.nn.optimizer.fused_adam import FusedAdam
|
|
from colossalai.testing import parameterize
|
|
|
|
|
|
class FC(nn.Module):
    """Minimal network used by the optimizer test.

    Holds a single ``nn.Linear(64, 64)`` wrapped in an ``nn.Sequential``
    stored under the attribute ``fc``.
    """

    def __init__(self) -> None:
        super().__init__()
        linear = nn.Linear(64, 64)
        # Keep the Sequential wrapper so parameter names stay ``fc.0.*``.
        self.fc = nn.Sequential(linear)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x`` and return the result."""
        out = self.fc(x)
        return out
|
|
|
|
|
|
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    """Compare ``FusedAdam`` against the matching torch optimizer.

    Two ``FC`` models with identical weights are trained on the same data:
    one with ``FusedAdam`` and one with torch ``Adam`` (or ``AdamW`` when
    ``adamw`` is true). Afterwards their parameters must agree within
    ``rtol=2e-3`` / ``atol=2e-3``.

    Args:
        adamw: When true, use ``FusedAdam(adamw_mode=True)`` and torch ``AdamW``.
        p_dtype: dtype the models, inputs and labels are cast to.
        g_dtype: dtype the gradients are cast to before ``FusedAdam.step()``
            whenever it differs from ``p_dtype``.
    """
    # Build the model under test and a reference copy sharing its weights.
    model = FC().cuda().to(p_dtype)
    state = model.state_dict()
    model_copy = FC().cuda().to(p_dtype)
    model_copy.load_state_dict(state.copy())

    # Pick the fused-optimizer flags and the torch reference class together.
    if adamw:
        fused_kwargs, reference_cls = {'adamw_mode': True}, AdamW
    else:
        fused_kwargs, reference_cls = {}, Adam
    optim = FusedAdam(model.parameters(), lr=1e-3, **fused_kwargs)
    torch_optim = reference_cls(model_copy.parameters(), lr=1e-3)

    data = torch.rand(1024, 64).cuda().to(p_dtype)
    data_copy = data.clone()
    label = torch.rand(1024, 64).cuda().to(p_dtype)

    fused_params = optim.param_groups[0]['params']
    reference_params = torch_optim.param_groups[0]['params']
    cast_grads = p_dtype != g_dtype

    # Train with FusedAdam, one sample per step.
    for sample, target in zip(data, label):
        prediction = model(sample)
        loss = ((target - prediction) ** 2).sum()
        optim.zero_grad()
        loss.backward()
        if cast_grads:
            # Feed the fused optimizer gradients in the requested dtype.
            for param in fused_params:
                param.grad.data = param.grad.data.to(g_dtype)
        optim.step()

    # Train the reference model with the torch optimizer on the same samples.
    for sample, target in zip(data_copy, label):
        prediction = model_copy(sample)
        loss = ((target - prediction) ** 2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(fused_params) == len(reference_params)

    for fused_p, reference_p in zip(fused_params, reference_params):
        # Parameters containing NaN on either side are not compared.
        has_nan = torch.isnan(fused_p).any() or torch.isnan(reference_p).any()
        if has_nan:
            continue
        assert torch.allclose(fused_p, reference_p, rtol=2e-3, atol=2e-3)
|