# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 96 lines, 2.6 KiB)
from numpy import dtype
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import math
|
|
|
|
from colossalai.testing import parameterize
|
|
from colossalai.utils import multi_tensor_applier
|
|
|
|
|
|
def torch_adam_update(step, lr, beta1, beta2, eps, weight_decay,
                      param, grad, exp_avg, exp_avg_sq, use_adamw):
    """Apply one reference Adam / AdamW step in place using plain PyTorch ops.

    Args:
        step: step count used for bias correction (``1 - beta**step``).
        lr: learning rate.
        beta1: decay rate of the first-moment running average.
        beta2: decay rate of the second-moment running average.
        eps: term added to the bias-corrected denominator.
        weight_decay: decay coefficient; ``0`` disables weight decay entirely.
        param: parameter tensor, updated in place.
        grad: gradient tensor; never modified (the L2 path works on a copy).
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        use_adamw: when truthy (and weight_decay != 0), shrink ``param`` directly
            by ``1 - lr * weight_decay``; otherwise add ``weight_decay * param``
            to the gradient before the moment updates.

    Returns:
        None. All results are written into ``param``, ``exp_avg`` and ``exp_avg_sq``.
    """
    if weight_decay != 0 and use_adamw:
        # Decoupled (AdamW-style) decay applied to the parameter itself.
        param.mul_(1 - lr * weight_decay)
    elif weight_decay != 0:
        # L2-style decay: out-of-place add so the caller's grad is untouched.
        grad = grad.add(param, alpha=weight_decay)

    # m <- beta1 * m + (1 - beta1) * g
    exp_avg.mul_(beta1)
    exp_avg.add_(grad, alpha=1 - beta1)
    # v <- beta2 * v + (1 - beta2) * g * g
    exp_avg_sq.mul_(beta2)
    exp_avg_sq.addcmul_(grad, grad, value=1 - beta2)

    correction1 = 1 - beta1**step
    correction2 = 1 - beta2**step
    denom = exp_avg_sq.sqrt().div(math.sqrt(correction2)).add_(eps)

    # param <- param - (lr / correction1) * m / denom
    param.addcdiv_(exp_avg, denom, value=-(lr / correction1))
|
|
|
|
|
|
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
    """Compare colossal_C's fused multi-tensor Adam against ``torch_adam_update``.

    Each parameter combination runs 1024 random trials. A trial runs one fused
    step on CUDA tensors (params in ``p_dtype``, grads in ``g_dtype``,
    default-dtype moments). It runs the same step on float copies through the
    PyTorch reference. The two updated params must then agree within
    rtol = atol = 1e-5. Trials where either result contains NaN are skipped,
    and reaching 200 such trials fails the test.

    Raises:
        ImportError: if the ``colossal_C`` extension cannot be imported.
    """
    # Guard only the import. A bare ``except`` around everything would also
    # relabel CUDA/attribute errors (and Ctrl-C) as a missing kernel. Chain the
    # cause so the real import failure stays visible.
    try:
        import colossal_C
    except ImportError as err:
        raise ImportError("No colossal_C kernel installed.") from err
    fused_adam = colossal_C.multi_tensor_adam
    dummy_overflow_buf = torch.cuda.IntTensor([0])

    # Hyper-parameters are identical for every trial, so set them once.
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    # NOTE(review): with weight_decay fixed at 0, torch_adam_update never reaches
    # its use_adamw branch, so adamw=True/False run the same reference math.
    # Consider parameterizing a nonzero weight decay.
    weight_decay = 0

    nan_count = 0
    for _ in range(1024):
        p = torch.rand(64, dtype=p_dtype).cuda()
        p_copy = p.clone().float()
        g = torch.rand(p.shape, dtype=g_dtype).cuda()
        g_copy = g.clone().float()
        m = torch.rand(p.shape).cuda()
        m_copy = m.clone()
        v = torch.rand(p.shape).cuda()
        v_copy = v.clone()

        # NOTE(review): the positional ``True`` presumably enables bias correction
        # inside the kernel. Confirm against the colossal_C signature.
        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,  # fp32 data
            g_copy,  # fp32 grad
            m_copy,
            v_copy,
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            nan_count += 1
            # Check the limit here rather than only on non-NaN trials, so a run
            # whose remaining trials are all NaN still fails at the limit.
            assert nan_count < 200, "too many nans"
            continue
        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), rtol=1e-5,
                              atol=1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
|