mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
198 lines
6.7 KiB
198 lines
6.7 KiB
3 years ago
|
# BSD 3-Clause License
|
||
|
#
|
||
|
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
|
||
|
#
|
||
|
# Redistribution and use in source and binary forms, with or without modification,
|
||
|
# are permitted provided that the following conditions are met:
|
||
|
#
|
||
|
# * Redistributions of source code must retain the above copyright notice, this
|
||
|
# list of conditions and the following disclaimer.
|
||
|
#
|
||
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
||
|
# this list of conditions and the following disclaimer in the documentation
|
||
|
# and/or other materials provided with the distribution.
|
||
|
#
|
||
|
# * Neither the name of the psutil authors nor the names of its contributors
|
||
|
# may be used to endorse or promote products derived from this software without
|
||
|
# specific prior written permission.
|
||
|
#
|
||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||
|
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||
|
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||
|
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||
|
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||
|
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||
|
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||
|
|
||
|
import math
|
||
|
import torch
|
||
|
import colossalai
|
||
|
try:
|
||
|
import cpu_adam
|
||
|
except ImportError:
|
||
|
raise ImportError("import cpu_adam error")
|
||
|
|
||
|
def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    bias_correction,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    loss_scale,
    use_adamw,
):
    """Apply one Adam/AdamW step in pure PyTorch as a reference for ``cpu_adam``.

    ``param``, ``exp_avg`` and ``exp_avg_sq`` are updated in place.
    When ``loss_scale > 0`` the incoming ``grad`` tensor is also unscaled
    in place (``grad.div_(loss_scale)``). ``Test.check_res`` relies on this
    mutation when it later compares gradients.

    Args:
        step: optimizer step count used in the bias-correction terms
            ``1 - beta ** step`` (the test passes 1).
        lr: learning rate.
        beta1: decay rate of the first-moment running average.
        beta2: decay rate of the second-moment running average.
        eps: value added to the denominator for numerical stability.
        weight_decay: decay coefficient; 0 disables weight decay.
        bias_correction: if truthy, divide the moments by ``1 - beta ** step``;
            otherwise use the raw running averages.
        param: parameter tensor, updated in place.
        grad: gradient tensor, possibly multiplied by ``loss_scale``.
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        loss_scale: divisor applied to ``grad`` when positive.
        use_adamw: if True, apply decoupled (AdamW) weight decay to ``param``;
            otherwise add an L2 term to the gradient.
    """
    if loss_scale > 0:
        grad.div_(loss_scale)

    # Previously ``bias_correction`` was accepted but silently ignored;
    # honour it so callers can request the uncorrected update.
    if bias_correction:
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
    else:
        bias_correction1 = 1.0
        bias_correction2 = 1.0

    if weight_decay != 0:
        if use_adamw:
            # Decoupled weight decay: shrink the parameter directly.
            param.mul_(1 - lr * weight_decay)
        else:
            # L2 penalty folded into the gradient. Out-of-place, so the
            # caller's grad tensor does not absorb the weight-decay term.
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)
|
||
|
|
||
|
|
||
|
class Test():
    """Compare the compiled ``cpu_adam`` kernel against ``torch_adam_update``."""

    def __init__(self):
        # Optimizer id used both to create and to step the cpu_adam optimizer.
        self.opt_id = 0

    def assertLess(self, data_diff, threshold, msg):
        assert data_diff < threshold, msg

    def assertTrue(self, condition, msg):
        assert condition, msg

    @staticmethod
    def _threshold(grad_dtype):
        """Return the max absolute difference tolerated for ``grad_dtype`` inputs.

        fp16 inputs get a looser bound (2e-3) than fp32 inputs (1e-4).
        The previous ``2e-3 if grad_dtype else 1e-4`` always returned 2e-3,
        because a ``torch.dtype`` object is always truthy.
        """
        return 2e-3 if grad_dtype == torch.half else 1e-4

    def check_res(
        self,
        step,
        lr,
        eps,
        beta1,
        beta2,
        weight_decay,
        shape,
        grad_dtype,
        loss_scale,
        use_adamw,
        cpu_adam_op,
    ):
        """Run one cpu_adam step and one reference step on equal inputs, then compare.

        Parameter and gradient tensors are created in ``grad_dtype``.
        Moment buffers are created with torch's default dtype. fp32 copies
        go to ``torch_adam_update``; the originals go to
        ``cpu_adam_op.adam_update``.

        Raises:
            AssertionError: if param, grad, exp_avg or exp_avg_sq differ by
                at least the dtype-dependent threshold.
        """
        p_data = torch.rand(shape, dtype=grad_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(shape, dtype=grad_dtype)
        if loss_scale > 0:
            # Pre-scale the gradient so the update has to unscale it.
            p_grad.mul_(loss_scale)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        # Create and step the optimizer under the same id. The original code
        # hard-coded 0 here while stepping ``self.opt_id``. The trailing
        # positional flags are passed through unchanged; their meaning is
        # defined by the cpu_adam extension.
        cpu_adam_op.create_adam(self.opt_id, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
        cpu_adam_op.adam_update(
            self.opt_id,
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,  # presumably bias_correction, matching the reference call; confirm
            p_data.view(-1),  # data in grad_dtype (fp16 or fp32)
            p_grad.view(-1),  # grad in grad_dtype, possibly loss-scaled
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            loss_scale,
        )

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,
            p_data_copy,  # fp32 data
            p_grad_copy,  # fp32 grad
            exp_avg_copy,
            exp_avg_sq_copy,
            loss_scale,
            use_adamw,
        )

        if loss_scale > 0:
            # torch_adam_update unscaled p_grad_copy in place; bring p_grad
            # to the same scale so the two gradients are comparable.
            p_grad.div_(loss_scale)

        threshold = self._threshold(grad_dtype)

        data_diff = torch.max(torch.abs(p_data_copy - p_data))
        self.assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        self.assertTrue(
            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
        )

    def test_cpu_adam(self):
        """Run ``check_res`` over every combination of the hyper-parameter grid below."""
        import itertools

        grid = itertools.product(
            [False, True],  # use_adamw
            [(1023, ), (32, 1024)],  # shape
            range(1, 2),  # step
            [0.01],  # lr
            [1e-8],  # eps
            [0.9],  # beta1
            [0.999],  # beta2
            [0.001],  # weight_decay
            [torch.half, torch.float],  # grad_dtype
            [-1, 2 ** 5],  # loss_scale (-1 disables scaling)
        )
        # itertools.product varies the last iterable fastest, which matches
        # the nesting order of the original loops.
        for (use_adamw, shape, step, lr, eps, beta1, beta2, weight_decay,
             grad_dtype, loss_scale) in grid:
            self.check_res(
                step,
                lr,
                eps,
                beta1,
                beta2,
                weight_decay,
                shape,
                grad_dtype,
                loss_scale,
                use_adamw,
                cpu_adam,
            )
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: run the whole parameter sweep, then report success.
    Test().test_cpu_adam()
    print('All is well.')
|