# Mirror of https://github.com/hpcaitech/ColossalAI
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize


class FC(nn.Module):
    """A minimal single-layer model used to exercise the optimizers."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(64, 64))

    def forward(self, x):
        return self.fc(x)


@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    # Build two identical models: one trained with FusedAdam, the other with the
    # corresponding PyTorch reference optimizer.
    model = FC().cuda().to(p_dtype)
    state = model.state_dict()
    model_copy = FC().cuda().to(p_dtype)
    model_copy.load_state_dict(state.copy())

    if adamw:
        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
    else:
        optim = FusedAdam(model.parameters(), lr=1e-3)
        torch_optim = Adam(model_copy.parameters(), lr=1e-3)

    data = torch.rand(1024, 64).cuda().to(p_dtype)
    data_copy = data.clone()
    label = torch.rand(1024, 64).cuda().to(p_dtype)

    # Train with FusedAdam, casting gradients to g_dtype when it differs from the
    # parameter dtype so the mixed-precision path is exercised.
    for d, l in zip(data, label):
        y = model(d)
        loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
            for p in optim.param_groups[0]['params']:
                p.grad.data = p.grad.data.to(g_dtype)
        optim.step()

    # Train the copy with the PyTorch reference optimizer on the same data.
    for d, l in zip(data_copy, label):
        y = model_copy(d)
        loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    # The two optimizers should produce matching parameters; half-precision runs
    # that overflow to NaN are skipped rather than compared.
    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

    for i in range(len(optim.param_groups[0]['params'])):
        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
            continue
        assert torch.allclose(optim.param_groups[0]['params'][i],
                              torch_optim.param_groups[0]['params'][i],
                              rtol=2e-3, atol=2e-3)
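

# A minimal sketch of a direct entry point, assuming a CUDA device is available:
# `parameterize` from colossalai.testing expands the decorated function over every
# (adamw, p_dtype, g_dtype) combination when it is called with no arguments, so a
# plain call runs the full sweep outside of pytest.
if __name__ == '__main__':
    test_adam()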