mirror of https://github.com/hpcaitech/ColossalAI
[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)
parent 920c5889a7
commit 6a3f9fda83
@@ -22,7 +22,7 @@ typedef enum

 using MATH_T = float;

-template <typename T>
+template <typename T_g, typename T_p>
 struct AdamFunctor
 {
   __device__ __forceinline__ void operator()(
@@ -50,16 +50,16 @@ struct AdamFunctor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];

-    T *g = (T *)tl.addresses[0][tensor_loc];
+    T_g *g = (T_g *)tl.addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;

-    T *p = (T *)tl.addresses[1][tensor_loc];
+    T_p *p = (T_p *)tl.addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;

-    T *m = (T *)tl.addresses[2][tensor_loc];
+    T_p *m = (T_p *)tl.addresses[2][tensor_loc];
     m += chunk_idx * chunk_size;

-    T *v = (T *)tl.addresses[3][tensor_loc];
+    T_p *v = (T_p *)tl.addresses[3][tensor_loc];
     v += chunk_idx * chunk_size;

     n -= chunk_idx * chunk_size;
@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
+  DISPATCH_FLOAT_AND_HALF_FOR_G_P(
+      tensor_lists[0][0].scalar_type(),
+      tensor_lists[1][0].scalar_type(), 0, "adam",
       multi_tensor_apply<4>(
           BLOCK_SIZE,
           chunk_size,
           noop_flag,
           tensor_lists,
-          AdamFunctor<scalar_t_0>(),
+          AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
           beta1,
           beta2,
           bias_correction1,
@@ -173,6 +173,36 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)      \
+  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float)      \
+  {                                                                          \
+    using g_scalar_t_##LEVEL = float;                                        \
+    using p_scalar_t_##LEVEL = float;                                        \
+    __VA_ARGS__;                                                             \
+  }                                                                          \
+  else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half)  \
+  {                                                                          \
+    using g_scalar_t_##LEVEL = float;                                        \
+    using p_scalar_t_##LEVEL = at::Half;                                     \
+    __VA_ARGS__;                                                             \
+  }                                                                          \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float)  \
+  {                                                                          \
+    using g_scalar_t_##LEVEL = at::Half;                                     \
+    using p_scalar_t_##LEVEL = float;                                        \
+    __VA_ARGS__;                                                             \
+  }                                                                          \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half)   \
+  {                                                                          \
+    using g_scalar_t_##LEVEL = at::Half;                                     \
+    using p_scalar_t_##LEVEL = at::Half;                                     \
+    __VA_ARGS__;                                                             \
+  }                                                                          \
+  else                                                                       \
+  {                                                                          \
+    AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
+  }                                                                          \
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                      T val,
@@ -10,7 +10,7 @@ class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm.

     Currently GPU-only. Requires ColossalAI to be installed via
-    ``pip install -v --no-cache-dir --global-option="--cuda_ext" ./``.
+    ``pip install .``.

     This version of fused Adam implements 2 fusions.
@@ -18,7 +18,7 @@ class FusedAdam(torch.optim.Optimizer):
     * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

     :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
-    or ``torch.optim.Adam`` with ``adam_w_mode=False``
+    or ``torch.optim.Adam`` with ``adamw_mode=False``

     :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
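For orientation, a minimal usage sketch of the optimizer this diff modifies (the model, sizes and hyperparameters below are illustrative and assume a CUDA build of ColossalAI with the colossal_C extension installed; they are not part of the commit):

    import torch
    from colossalai.nn.optimizer.fused_adam import FusedAdam

    # Illustrative fp32 model; FusedAdam is constructed like torch.optim.AdamW.
    model = torch.nn.Linear(64, 64).cuda()
    optimizer = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)

    out = model(torch.rand(8, 64, device='cuda'))
    out.sum().backward()
    optimizer.step()        # one fused multi-tensor kernel launch for all params
    optimizer.zero_grad()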
@@ -36,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False) NOT SUPPORTED in FusedAdam!
-        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
+        adamw_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)
@@ -53,7 +53,7 @@ class FusedAdam(torch.optim.Optimizer):
                  bias_correction=True,
                  betas=(0.9, 0.999),
                  eps=1e-8,
-                 adam_w_mode=True,
+                 adamw_mode=True,
                  weight_decay=0.,
                  amsgrad=False,
                  set_grad_none=True):
@@ -62,7 +62,7 @@ class FusedAdam(torch.optim.Optimizer):
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
         super(FusedAdam, self).__init__(params, defaults)
-        self.adam_w_mode = 1 if adam_w_mode else 0
+        self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
             import colossal_C
@@ -109,8 +109,7 @@ class FusedAdam(torch.optim.Optimizer):
                 group['step'] = 1

             # create lists for multi-tensor apply
-            g_16, p_16, m_16, v_16 = [], [], [], []
-            g_32, p_32, m_32, v_32 = [], [], [], []
+            g_l, p_l, m_l, v_l = [], [], [], []

             for p in group['params']:
                 if p.grad is None:
@@ -127,26 +126,16 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                if p.dtype == torch.float16:
-                    g_16.append(p.grad.data)
-                    p_16.append(p.data)
-                    m_16.append(state['exp_avg'])
-                    v_16.append(state['exp_avg_sq'])
-                elif p.dtype == torch.float32:
-                    g_32.append(p.grad.data)
-                    p_32.append(p.data)
-                    m_32.append(state['exp_avg'])
-                    v_32.append(state['exp_avg_sq'])
-                else:
+                if p.dtype not in [torch.float16, torch.float32]:
                     raise RuntimeError('FusedAdam only support fp16 and fp32.')

+                g_l.append(p.grad.data)
+                p_l.append(p.data)
+                m_l.append(state['exp_avg'])
+                v_l.append(state['exp_avg_sq'])

-            if (len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
-                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                     bias_correction, group['weight_decay'])
-            if (len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
-                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                     bias_correction, group['weight_decay'])
+            multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l],
+                                 group['lr'], beta1, beta2, group['eps'], group['step'], self.adamw_mode,
+                                 bias_correction, group['weight_decay'])

         return loss
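Because the kernel now dispatches on the gradient dtype and the parameter dtype independently (see the DISPATCH_FLOAT_AND_HALF_FOR_G_P hunk above), step() no longer has to bucket tensors into fp16 and fp32 groups. A small sketch of the hybrid case this allows, mirroring the pattern used in the new unit test further down (the layer and the gradient cast below are illustrative, not part of the diff):

    import torch
    from colossalai.nn.optimizer.fused_adam import FusedAdam

    layer = torch.nn.Linear(64, 64).cuda().half()        # fp16 parameters
    opt = FusedAdam(layer.parameters(), lr=1e-3)

    layer(torch.rand(8, 64, device='cuda', dtype=torch.float16)).sum().backward()
    # Keep fp16 params but hand the kernel fp32 gradients; with the per-dtype
    # dispatch this mixed combination is handled in a single fused call.
    for p in opt.param_groups[0]['params']:
        p.grad.data = p.grad.data.float()
    opt.step()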
@@ -1,38 +1,7 @@
-# BSD 3-Clause License
-#
-# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without modification,
-# are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-#   list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-#
-# * Neither the name of the psutil authors nor the names of its contributors
-#   may be used to endorse or promote products derived from this software without
-#   specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
-# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
-# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
 import math
 import torch
-try:
-    import cpu_adam
-except ImportError:
-    raise ImportError("import cpu_adam error")
+
+from colossalai.testing import parameterize


 def torch_adam_update(
@@ -71,45 +40,46 @@ def torch_adam_update(
     param.addcdiv_(exp_avg, denom, value=-step_size)


-class Test():
+def assertLess(data_diff, threshold, msg):
+    assert data_diff < threshold, msg

-    def __init__(self):
-        self.opt_id = 0

-    def assertLess(self, data_diff, threshold, msg):
-        assert data_diff < threshold, msg
+def assertTrue(condition, msg):
+    assert condition, msg

-    def assertTrue(self, condition, msg):
-        assert condition, msg

-    def check_res(
-        self,
-        step,
-        lr,
-        eps,
-        beta1,
-        beta2,
-        weight_decay,
-        shape,
-        grad_dtype,
-        loss_scale,
-        use_adamw,
-        cpu_adam_op,
-    ):
-        p_data = torch.rand(shape, dtype=grad_dtype)
+@parameterize('adamw', [True, False])
+@parameterize('step', [1, 2])
+@parameterize('loss_scale', [-1, 2 ** 5])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
+    lr = 1e-3
+    beta1, beta2 = 0.9, 0.999
+    eps = 1e-8
+    weight_decay = 0

+    for i in range(1024):
+        p_data = torch.rand(64, dtype=p_dtype)
         p_data_copy = p_data.clone().float()
-        p_grad = torch.rand(shape, dtype=grad_dtype)
+        p_grad = torch.rand(64, dtype=g_dtype)
         if loss_scale > 0:
             p_grad.mul_(loss_scale)
         p_grad_copy = p_grad.clone().float()
-        exp_avg = torch.rand(shape)
+        exp_avg = torch.rand(p_data.shape)
         exp_avg_copy = exp_avg.clone()
-        exp_avg_sq = torch.rand(shape)
+        exp_avg_sq = torch.rand(p_data.shape)
         exp_avg_sq_copy = exp_avg_sq.clone()

-        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
+        try:
+            import cpu_adam
+            cpu_adam_op = cpu_adam
+        except:
+            raise ImportError("...")

+        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
         cpu_adam_op.adam_update(
-            self.opt_id,
+            0,
             step,
             lr,
             beta1,
@@ -136,62 +106,24 @@ class Test():
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
-            use_adamw,
+            adamw,
         )

         if loss_scale > 0:
             p_grad.div_(loss_scale)

         var = p_data_copy - p_data
         data_diff = torch.max(torch.abs(var))
-        threshold = 2e-3 if grad_dtype else 1e-4
-        self.assertLess(
+        threshold = 1e-3
+        print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
+              f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
+        assertLess(
             data_diff,
             threshold,
-            f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
-            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
+            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
+            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
         )
         max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
-        self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
+        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
-        self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
+        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
+        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")

-    def test_cpu_adam(self):
-        lr = 0.9
-        eps = 1e-6
-        weight_decay = 0
-        for use_adamw in [False, True]:
-            for shape in [(23,), (8, 24)]:
-                for step in range(1, 2):
-                    for lr in [0.01]:
-                        for eps in [1e-8]:
-                            for beta1 in [0.9]:
-                                for beta2 in [0.999]:
-                                    for weight_decay in [0.001]:
-                                        for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2**5]:
-                                                self.check_res(
-                                                    step,
-                                                    lr,
-                                                    eps,
-                                                    beta1,
-                                                    beta2,
-                                                    weight_decay,
-                                                    shape,
-                                                    grad_dtype,
-                                                    loss_scale,
-                                                    use_adamw,
-                                                    cpu_adam,
-                                                )


-def test_cpu_adam():
-    test_case = Test()
-    test_case.test_cpu_adam()


 if __name__ == "__main__":
-    test = Test()
-    test.test_cpu_adam()
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from torch.optim.adam import Adam
+from torch.optim import AdamW
+
+from colossalai.nn.optimizer.fused_adam import FusedAdam
+from colossalai.testing import parameterize
+
+
+class FC(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Sequential(nn.Linear(64, 64))
+    def forward(self, x):
+        return self.fc(x)
+
+
+@parameterize('adamw', [False, True])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_adam(adamw, p_dtype, g_dtype):
+    model = FC().cuda().to(p_dtype)
+    state = model.state_dict()
+    model_copy = FC().cuda().to(p_dtype)
+    model_copy.load_state_dict(state.copy())
+
+    if adamw:
+        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
+        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
+    else:
+        optim = FusedAdam(model.parameters(), lr=1e-3)
+        torch_optim = Adam(model_copy.parameters(), lr=1e-3)
+
+    data = torch.rand(1024, 64).cuda().to(p_dtype)
+    data_copy = data.clone()
+    label = torch.rand(1024, 64).cuda().to(p_dtype)
+
+    for d, l in zip(data, label):
+        y = model(d)
+        loss = ((l - y) ** 2).sum()
+        optim.zero_grad()
+        loss.backward()
+        if p_dtype != g_dtype:
+            for i in range(len(optim.param_groups[0]['params'])):
+                optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
+        optim.step()
+
+    for d, l in zip(data_copy, label):
+        y = model_copy(d)
+        loss = ((l - y) ** 2).sum()
+        torch_optim.zero_grad()
+        loss.backward()
+        torch_optim.step()
+
+    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])
+
+    for i in range(len(optim.param_groups[0]['params'])):
+        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
+                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
+            continue
+        assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)
@@ -0,0 +1,98 @@
+from numpy import dtype
+import torch
+import torch.nn as nn
+
+import math
+
+from colossalai.testing import parameterize
+from colossalai.utils import multi_tensor_applier
+
+
+def torch_adam_update(
+    step,
+    lr,
+    beta1,
+    beta2,
+    eps,
+    weight_decay,
+    param,
+    grad,
+    exp_avg,
+    exp_avg_sq,
+    loss_scale,
+    use_adamw,
+):
+    if loss_scale > 0:
+        grad.div_(loss_scale)
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
+
+    if weight_decay != 0:
+        if use_adamw:
+            # Perform stepweight decay
+            param.mul_(1 - lr * weight_decay)
+        else:
+            grad = grad.add(param, alpha=weight_decay)
+
+    # Decay the first and second moment running average coefficient
+    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+    step_size = lr / bias_correction1
+
+    param.addcdiv_(exp_avg, denom, value=-step_size)
+
+
+@parameterize('adamw', [False, True])
+@parameterize('step', [1, 2])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_adam(adamw, step, p_dtype, g_dtype):
+    try:
+        import colossal_C
+        fused_adam = colossal_C.multi_tensor_adam
+        dummy_overflow_buf = torch.cuda.IntTensor([0])
+    except:
+        raise ImportError("No colossal_C kernel installed.")
+
+    count = 0
+
+    for i in range(1024):
+        p = torch.rand(64, dtype=p_dtype).cuda()
+        p_copy = p.clone().float()
+        g = torch.rand(p.shape, dtype=g_dtype).cuda()
+        g_copy = g.clone().float()
+        m = torch.rand(p.shape).cuda()
+        m_copy = m.clone()
+        v = torch.rand(p.shape).cuda()
+        v_copy = v.clone()
+
+        lr = 1e-3
+        beta1, beta2 = 0.9, 0.999
+        eps = 1e-8
+        weight_decay = 0
+
+        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
+                             lr, beta1, beta2, eps, step, adamw,
+                             True, weight_decay)
+
+        torch_adam_update(
+            step,
+            lr,
+            beta1,
+            beta2,
+            eps,
+            weight_decay,
+            p_copy,    # fp32 data
+            g_copy,    # fp32 grad
+            m_copy,
+            v_copy,
+            -1,
+            adamw,
+        )
+
+        if torch.isnan(p).any() or torch.isnan(p_copy).any():
+            count += 1
+            continue
+        assert count < 200, "too many nans"
+        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"