2022-03-04 08:05:15 +00:00
|
|
|
import math
|
2022-11-17 05:42:33 +00:00
|
|
|
|
2022-03-04 08:05:15 +00:00
|
|
|
import torch
|
2022-03-25 06:15:53 +00:00
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
from colossalai.testing import clear_cache_before_run, parameterize
|
2022-03-04 08:05:15 +00:00
|
|
|
|
2022-03-16 02:39:55 +00:00
|
|
|
|
2022-03-04 08:05:15 +00:00
|
|
|
def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    use_adamw,
):
    """Apply one bias-corrected Adam / AdamW step in place using plain PyTorch ops.

    This is the reference implementation the fused CPU kernel is compared with.

    Args:
        step: step index used for bias correction; must be >= 1, since step 0
            makes ``1 - beta1**step`` zero and ``lr / bias_correction1`` divides by it.
        lr: learning rate.
        beta1: decay rate of the first-moment running average.
        beta2: decay rate of the second-moment running average.
        eps: term added to the denominator for numerical stability.
        weight_decay: decay coefficient; 0 disables decay entirely.
        param: parameter tensor, updated in place.
        grad: gradient tensor; never modified (L2 decay builds a new tensor).
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        use_adamw: when True, decay is applied directly to ``param`` (decoupled,
            AdamW style); otherwise ``weight_decay * param`` is added to the gradient.

    Returns:
        None. All results are written into ``param``, ``exp_avg`` and ``exp_avg_sq``.
    """
    decay_enabled = weight_decay != 0
    if decay_enabled and use_adamw:
        # Decoupled decay: shrink the parameter before the moment update.
        param.mul_(1 - lr * weight_decay)
    elif decay_enabled:
        # L2 regularisation: fold the decay term into a copy of the gradient.
        grad = grad.add(param, alpha=weight_decay)

    # Exponential moving averages of the gradient and the squared gradient.
    one_minus_beta1 = 1 - beta1
    one_minus_beta2 = 1 - beta2
    exp_avg.mul_(beta1).add_(grad, alpha=one_minus_beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=one_minus_beta2)

    # Bias-correct both moments; the first-moment correction is folded into the step size.
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = exp_avg_sq.sqrt().div(math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))
|
|
|
|
|
|
|
|
|
2022-03-25 06:15:53 +00:00
|
|
|
def assertLess(data_diff, threshold, msg):
    """Assert that ``data_diff < threshold``, reporting ``msg`` on failure."""
    is_below = data_diff < threshold
    assert is_below, msg
|
|
|
|
|
|
|
|
|
|
|
|
def assertTrue(condition, msg):
    """Assert that ``condition`` is truthy, reporting ``msg`` on failure."""
    passed = bool(condition)
    assert passed, msg
|
|
|
|
|
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
@clear_cache_before_run()
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, p_dtype, g_dtype):
    """Compare the fused CPU Adam kernel against the ``torch_adam_update`` reference.

    Each of three trials draws random parameter, gradient and moment tensors.
    It runs one kernel step on them and one reference step on fp32 copies.
    It then checks that parameters, gradients and both moment buffers agree
    within ``threshold``.

    Args:
        adamw: passed to both implementations to select decoupled (AdamW) decay.
        step: step index used for bias correction.
        p_dtype: dtype of the parameter tensor handed to the kernel.
        g_dtype: dtype of the gradient tensor handed to the kernel.
    """
    from colossalai.kernel.op_builder import CPUAdamBuilder

    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    weight_decay = 0
    threshold = 1e-3

    # Loading the compiled extension does not depend on the random data,
    # so do it once rather than on every trial.
    cpu_optim = CPUAdamBuilder().load()

    for _ in range(3):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(64, dtype=g_dtype)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(p_data.shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(p_data.shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        # A fresh optimizer object per trial (as before), in case the kernel
        # object keeps internal state between calls.
        cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
        cpu_adam_op.step(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,  # NOTE(review): presumably the bias-correction flag -- confirm against the kernel signature
            p_data.view(-1),  # dtype p_dtype (fp32 or fp16)
            p_grad.view(-1),  # dtype g_dtype (fp32 or fp16)
            exp_avg.view(-1),  # default float dtype
            exp_avg_sq.view(-1),  # default float dtype
            -1,  # NOTE(review): presumably "no loss scaling" -- confirm against the kernel signature
        )

        # Reference update on fp32 copies of the same inputs.
        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_data_copy,
            p_grad_copy,
            exp_avg_copy,
            exp_avg_sq_copy,
            adamw,
        )

        data_diff = torch.max(torch.abs(p_data_copy - p_data))
        assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
|
2022-12-23 06:14:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running the check as a plain script, without pytest collecting it.
    test_cpu_adam()
|