[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)

pull/522/head
LuGY 2022-03-25 14:15:53 +08:00 committed by GitHub
parent 920c5889a7
commit 6a3f9fda83
6 changed files with 253 additions and 143 deletions

View File

@@ -22,7 +22,7 @@ typedef enum
using MATH_T = float;
template <typename T>
template <typename T_g, typename T_p>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
@@ -50,16 +50,16 @@ struct AdamFunctor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T *g = (T *)tl.addresses[0][tensor_loc];
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T *p = (T *)tl.addresses[1][tensor_loc];
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T *m = (T *)tl.addresses[2][tensor_loc];
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T *v = (T *)tl.addresses[3][tensor_loc];
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
bias_correction2 = 1 - std::pow(beta2, step);
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "adam",
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
tensor_lists[0][0].scalar_type(),
tensor_lists[1][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
beta1,
beta2,
bias_correction1,
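For orientation, here is a minimal PyTorch sketch (a hypothetical adam_update_mixed helper, not the kernel itself) of the per-element update the templated AdamFunctor<T_g, T_p> performs: gradients are loaded as T_g, the parameter and both moments as T_p, and all arithmetic runs in fp32 (MATH_T), matching the torch_adam_update reference used in the tests below.

import torch

def adam_update_mixed(p, g, m, v, lr, beta1, beta2, eps, step,
                      adamw_mode=True, weight_decay=0.0):
    # promote every operand to fp32, as the kernel casts each load to MATH_T
    pf = p.detach().clone().float()
    gf = g.detach().clone().float()
    mf, vf = m.clone().float(), v.clone().float()
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if weight_decay != 0:
        if adamw_mode:
            pf.mul_(1 - lr * weight_decay)   # decoupled weight decay (AdamW)
        else:
            gf.add_(pf, alpha=weight_decay)  # classic L2 regularization
    mf.mul_(beta1).add_(gf, alpha=1 - beta1)          # first moment, fp32
    vf.mul_(beta2).addcmul_(gf, gf, value=1 - beta2)  # second moment, fp32
    denom = (vf.sqrt() / bias_correction2 ** 0.5).add_(eps)
    pf.addcdiv_(mf, denom, value=-lr / bias_correction1)
    # store results back in the tensors' own dtypes, as the kernel stores T_p
    p.data.copy_(pf)
    m.data.copy_(mf)
    v.data.copy_(vf)

Any of the four fp16/fp32 combinations of g and p works here, which is exactly the freedom the two template parameters give the kernel.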

View File

@@ -173,6 +173,36 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
{ \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} \
else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
{ \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} \
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
{ \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} \
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
{ \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} \
else \
{ \
AT_ERROR(#NAME, " not implemented for '", toString(GTYPE), "' and '", toString(PTYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
T val,
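The macro is, in effect, a compile-time dispatch table keyed on the (grad dtype, param dtype) pair. A rough runtime analogue in Python (names hypothetical, for illustration only):

import torch

# the four (g_dtype, p_dtype) pairs the macro instantiates
SUPPORTED = {
    (torch.float32, torch.float32),
    (torch.float32, torch.float16),
    (torch.float16, torch.float32),
    (torch.float16, torch.float16),
}

def dispatch_g_p(g_dtype, p_dtype, kernels):
    # 'kernels' maps a dtype pair to the matching AdamFunctor<T_g, T_p> launch
    if (g_dtype, p_dtype) not in SUPPORTED:
        raise TypeError(f"adam not implemented for '{g_dtype}' and '{p_dtype}'")
    return kernels[(g_dtype, p_dtype)]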

View File

@@ -10,7 +10,7 @@ class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Currently GPU-only. Requires ColossalAI to be installed via
``pip install -v --no-cache-dir --global-option="--cuda_ext" ./``.
``pip install .``.
This version of fused Adam implements 2 fusions.
@@ -18,7 +18,7 @@ class FusedAdam(torch.optim.Optimizer):
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adam_w_mode=False``
or ``torch.optim.Adam`` with ``adamw_mode=False``.
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
@@ -36,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay (also known as AdamW) (default: True)
set_grad_none (bool, optional): whether to set the grads to None when the
zero_grad() method is called. (default: True)
@@ -53,7 +53,7 @@ class FusedAdam(torch.optim.Optimizer):
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
adam_w_mode=True,
adamw_mode=True,
weight_decay=0.,
amsgrad=False,
set_grad_none=True):
@@ -62,7 +62,7 @@ class FusedAdam(torch.optim.Optimizer):
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
super(FusedAdam, self).__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
import colossal_C
@@ -109,8 +109,7 @@ class FusedAdam(torch.optim.Optimizer):
group['step'] = 1
# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
g_l, p_l, m_l, v_l = [], [], [], []
for p in group['params']:
if p.grad is None:
@@ -127,26 +126,16 @@ class FusedAdam(torch.optim.Optimizer):
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if p.dtype == torch.float16:
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state['exp_avg'])
v_16.append(state['exp_avg_sq'])
elif p.dtype == torch.float32:
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state['exp_avg'])
v_32.append(state['exp_avg_sq'])
else:
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('FusedAdam only supports fp16 and fp32.')
g_l.append(p.grad.data)
p_l.append(p.data)
m_l.append(state['exp_avg'])
v_l.append(state['exp_avg_sq'])
if (len(g_16) > 0):
multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
bias_correction, group['weight_decay'])
if (len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
bias_correction, group['weight_decay'])
multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l],
group['lr'], beta1, beta2, group['eps'], group['step'], self.adamw_mode,
bias_correction, group['weight_decay'])
return loss
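A usage sketch of the updated optimizer (model and shapes made up): with the hybrid kernel, a single step can now consume fp16 gradients for fp32 parameters, which is the scenario the new tests below exercise.

import torch
import torch.nn as nn
from colossalai.nn.optimizer.fused_adam import FusedAdam

model = nn.Linear(64, 64).cuda()                 # fp32 parameters
optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)

loss = model(torch.rand(8, 64, device='cuda')).sum()
optim.zero_grad()
loss.backward()
for p in model.parameters():
    p.grad.data = p.grad.data.half()             # fp16 grads, fp32 params
optim.step()   # one fused multi-tensor call handles the mixed dtypes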

View File

@@ -1,38 +1,7 @@
# BSD 3-Clause License
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the psutil authors nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math
import torch
try:
import cpu_adam
except ImportError:
raise ImportError("import cpu_adam error")
from colossalai.testing import parameterize
def torch_adam_update(
@@ -71,45 +40,46 @@ def torch_adam_update(
param.addcdiv_(exp_avg, denom, value=-step_size)
class Test():
def assertLess(data_diff, threshold, msg):
assert data_diff < threshold, msg
def __init__(self):
self.opt_id = 0
def assertLess(self, data_diff, threshold, msg):
assert data_diff < threshold, msg
def assertTrue(condition, msg):
assert condition, msg
def assertTrue(self, condition, msg):
assert condition, msg
def check_res(
self,
step,
lr,
eps,
beta1,
beta2,
weight_decay,
shape,
grad_dtype,
loss_scale,
use_adamw,
cpu_adam_op,
):
p_data = torch.rand(shape, dtype=grad_dtype)
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('loss_scale', [-1, 2 ** 5])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
weight_decay = 0
for i in range(1024):
p_data = torch.rand(64, dtype=p_dtype)
p_data_copy = p_data.clone().float()
p_grad = torch.rand(shape, dtype=grad_dtype)
p_grad = torch.rand(64, dtype=g_dtype)
if loss_scale > 0:
p_grad.mul_(loss_scale)
p_grad_copy = p_grad.clone().float()
exp_avg = torch.rand(shape)
exp_avg = torch.rand(p_data.shape)
exp_avg_copy = exp_avg.clone()
exp_avg_sq = torch.rand(shape)
exp_avg_sq = torch.rand(p_data.shape)
exp_avg_sq_copy = exp_avg_sq.clone()
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
try:
import cpu_adam
cpu_adam_op = cpu_adam
except ImportError:
raise ImportError("...")
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
cpu_adam_op.adam_update(
self.opt_id,
0,
step,
lr,
beta1,
@@ -136,62 +106,24 @@ class Test():
exp_avg_copy,
exp_avg_sq_copy,
loss_scale,
use_adamw,
adamw,
)
if loss_scale > 0:
p_grad.div_(loss_scale)
var = p_data_copy - p_data
data_diff = torch.max(torch.abs(var))
threshold = 2e-3 if grad_dtype else 1e-4
self.assertLess(
threshold = 1e-3
print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
assertLess(
data_diff,
threshold,
f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
)
max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
def test_cpu_adam(self):
lr = 0.9
eps = 1e-6
weight_decay = 0
for use_adamw in [False, True]:
for shape in [(23,), (8, 24)]:
for step in range(1, 2):
for lr in [0.01]:
for eps in [1e-8]:
for beta1 in [0.9]:
for beta2 in [0.999]:
for weight_decay in [0.001]:
for grad_dtype in [torch.half, torch.float]:
for loss_scale in [-1, 2**5]:
self.check_res(
step,
lr,
eps,
beta1,
beta2,
weight_decay,
shape,
grad_dtype,
loss_scale,
use_adamw,
cpu_adam,
)
def test_cpu_adam():
test_case = Test()
test_case.test_cpu_adam()
if __name__ == "__main__":
test = Test()
test.test_cpu_adam()
assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
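Because @parameterize expands a decorated test over the Cartesian product of its argument lists, the rewritten test can still be launched directly; a minimal runner in place of the removed __main__ block might be:

if __name__ == "__main__":
    # runs every adamw x step x loss_scale x p_dtype x g_dtype combination
    test_cpu_adam()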

View File

@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW
from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize
class FC(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Sequential(nn.Linear(64, 64))
def forward(self, x):
return self.fc(x)
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
model = FC().cuda().to(p_dtype)
state = model.state_dict()
model_copy = FC().cuda().to(p_dtype)
model_copy.load_state_dict(state.copy())
if adamw:
optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
else:
optim = FusedAdam(model.parameters(), lr=1e-3)
torch_optim = Adam(model_copy.parameters(), lr=1e-3)
data = torch.rand(1024, 64).cuda().to(p_dtype)
data_copy = data.clone()
label = torch.rand(1024, 64).cuda().to(p_dtype)
for d, l in zip(data, label):
y = model(d)
loss = ((l - y) ** 2).sum()
optim.zero_grad()
loss.backward()
if p_dtype != g_dtype:
for i in range(len(optim.param_groups[0]['params'])):
optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
optim.step()
for d, l in zip(data_copy, label):
y = model_copy(d)
loss = ((l - y) ** 2).sum()
torch_optim.zero_grad()
loss.backward()
torch_optim.step()
assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])
for i in range(len(optim.param_groups[0]['params'])):
if torch.isnan(optim.param_groups[0]['params'][i]).any() \
or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
continue
assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)

View File

@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import math
from colossalai.testing import parameterize
from colossalai.utils import multi_tensor_applier
def torch_adam_update(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
param,
grad,
exp_avg,
exp_avg_sq,
loss_scale,
use_adamw,
):
if loss_scale > 0:
grad.div_(loss_scale)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
if weight_decay != 0:
if use_adamw:
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
param.addcdiv_(exp_avg, denom, value=-step_size)
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
try:
import colossal_C
fused_adam = colossal_C.multi_tensor_adam
dummy_overflow_buf = torch.cuda.IntTensor([0])
except ImportError:
raise ImportError("No colossal_C kernel installed.")
count = 0
for i in range(1024):
p = torch.rand(64, dtype=p_dtype).cuda()
p_copy = p.clone().float()
g = torch.rand(p.shape, dtype=g_dtype).cuda()
g_copy = g.clone().float()
m = torch.rand(p.shape).cuda()
m_copy = m.clone()
v = torch.rand(p.shape).cuda()
v_copy = v.clone()
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
weight_decay = 0
multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
lr, beta1, beta2, eps, step, adamw,
True, weight_decay)
torch_adam_update(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
p_copy, # fp32 data
g_copy, # fp32 grad
m_copy,
v_copy,
-1,
adamw,
)
if torch.isnan(p).any() or torch.isnan(p_copy).any():
count += 1
continue
assert count < 200, "too many nans"
assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"