You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_optimizer/unittest_cpu_adam.py

130 lines
3.8 KiB

import math
import torch
from colossalai.testing import parameterize
def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    loss_scale,
    use_adamw,
):
    """Apply a single Adam / AdamW step with plain PyTorch tensor operations.

    ``param``, ``exp_avg`` and ``exp_avg_sq`` are updated in place. ``grad`` is
    divided in place by ``loss_scale`` when ``loss_scale`` is positive. The
    function returns ``None``.

    Args:
        step: Step number used as the exponent in the bias corrections.
        lr: Learning rate.
        beta1: Decay coefficient of the first-moment running average.
        beta2: Decay coefficient of the second-moment running average.
        eps: Constant added to the denominator.
        weight_decay: Weight-decay coefficient; ``0`` disables decay.
        param: Parameter tensor (modified in place).
        grad: Gradient tensor (modified in place only by the loss-scale division).
        exp_avg: First-moment buffer (modified in place).
        exp_avg_sq: Second-moment buffer (modified in place).
        loss_scale: Gradient scale to undo; values ``<= 0`` mean "no scaling".
        use_adamw: If true, decay is applied directly to ``param``; otherwise
            it is folded into the gradient.
    """
    # Undo loss scaling directly on the caller's gradient tensor.
    needs_unscale = loss_scale > 0
    if needs_unscale:
        grad.div_(loss_scale)

    first_moment_correction = 1 - beta1**step
    second_moment_correction = 1 - beta2**step

    decay_enabled = weight_decay != 0
    if decay_enabled and use_adamw:
        # Decoupled (AdamW) decay: shrink the parameter itself.
        param.mul_(1 - lr * weight_decay)
    elif decay_enabled:
        # Classic L2 decay: out-of-place add, so the caller's grad tensor is
        # left untouched and only the local name is rebound.
        grad = grad.add(param, alpha=weight_decay)

    # Update the running averages of the gradient and of its square.
    exp_avg.mul_(beta1)
    exp_avg.add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2)
    exp_avg_sq.addcmul_(grad, grad, value=1 - beta2)

    corrected_rms = exp_avg_sq.sqrt() / math.sqrt(second_moment_correction)
    denom = corrected_rms.add_(eps)

    step_size = lr / first_moment_correction

    param.addcdiv_(exp_avg, denom, value=-step_size)
def assertLess(data_diff, threshold, msg):
    """Assert that ``data_diff`` is strictly smaller than ``threshold``.

    Raises ``AssertionError`` carrying ``msg`` when the comparison is false.
    """
    is_below = data_diff < threshold
    assert is_below, msg
def assertTrue(condition, msg):
    """Assert that ``condition`` is truthy.

    Raises ``AssertionError`` carrying ``msg`` when ``condition`` is falsy.
    """
    holds = bool(condition)
    assert holds, msg
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('loss_scale', [-1, 2 ** 5])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
    """Check the ``cpu_adam`` extension against the ``torch_adam_update`` reference.

    For 1024 random 64-element problems, one update is run through
    ``cpu_adam.adam_update`` on tensors of the requested dtypes and through
    ``torch_adam_update`` on fp32 copies of the same inputs.  The parameters,
    gradients and both moment buffers must then agree within a maximum
    absolute difference of ``1e-3``.

    Args:
        adamw: Forwarded to ``create_adam`` and, as ``use_adamw``, to the reference.
        step: Step number passed to both implementations.
        loss_scale: When positive, gradients are multiplied by it before the
            update and both implementations are asked to undo it.
        p_dtype: dtype of the parameter tensor handed to the extension.
        g_dtype: dtype of the gradient tensor handed to the extension.

    Raises:
        ImportError: If the ``cpu_adam`` module cannot be imported.
        AssertionError: If any compared quantity differs by ``1e-3`` or more.
    """
    # Import once, up front; chain the original error so the real cause
    # (missing build, missing shared library, ...) stays visible.
    try:
        import cpu_adam as cpu_adam_op
    except ImportError as err:
        raise ImportError(
            "failed to import the `cpu_adam` extension module required by this test"
        ) from err

    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    # With weight_decay == 0 the decay branches of the reference are never taken,
    # so `adamw` only affects how the extension optimizer is created.
    weight_decay = 0
    threshold = 1e-3

    for _ in range(1024):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(64, dtype=g_dtype)
        if loss_scale > 0:
            p_grad.mul_(loss_scale)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(p_data.shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(p_data.shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        # Deliberately re-created on every iteration, as in the original test.
        # NOTE(review): the first argument looks like an optimizer id and the
        # trailing False like a verbosity flag -- confirm against the extension.
        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
        cpu_adam_op.adam_update(
            0,
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,  # NOTE(review): presumably the bias-correction switch -- confirm
            p_data.view(-1),  # parameters in p_dtype (fp32 or fp16)
            p_grad.view(-1),  # gradients in g_dtype (fp32 or fp16)
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            loss_scale,
        )

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_data_copy,  # fp32 copy of the parameters
            p_grad_copy,  # fp32 copy of the gradients
            exp_avg_copy,
            exp_avg_sq_copy,
            loss_scale,
            adamw,
        )

        # The reference divided p_grad_copy by loss_scale in place; apply the
        # same division to p_grad so the gradient comparison below is like for like.
        if loss_scale > 0:
            p_grad.div_(loss_scale)

        data_diff = torch.max(torch.abs(p_data_copy - p_data))
        assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
        )

        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")

        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")

        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")