mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] run cpu adam unittest in pytest (#424)
parent 54229cd33e
commit 5d7dc3525b
@@ -29,12 +29,12 @@
 import math
 import torch
 import colossalai
 try:
     import cpu_adam
 except ImportError:
     raise ImportError("import cpu_adam error")


 def torch_adam_update(
     step,
     lr,
@@ -42,7 +42,6 @@ def torch_adam_update(
     beta2,
     eps,
     weight_decay,
     bias_correction,
     param,
     grad,
     exp_avg,
@@ -52,8 +51,8 @@ def torch_adam_update(
 ):
     if loss_scale > 0:
         grad.div_(loss_scale)
-    bias_correction1 = 1 - beta1 ** step
-    bias_correction2 = 1 - beta2 ** step
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step

     if weight_decay != 0:
         if use_adamw:
@@ -73,12 +72,13 @@ def torch_adam_update(


 class Test():

     def __init__(self):
         self.opt_id = 0

     def assertLess(self, data_diff, threshold, msg):
         assert data_diff < threshold, msg

     def assertTrue(self, condition, msg):
         assert condition, msg
@@ -89,7 +89,6 @@ class Test():
         eps,
         beta1,
         beta2,
-
         weight_decay,
         shape,
         grad_dtype,
@@ -118,8 +117,8 @@ class Test():
             eps,
             weight_decay,
             True,
-            p_data.view(-1), # fp32 data
-            p_grad.view(-1), # fp32 grad
+            p_data.view(-1),  # fp32 data
+            p_grad.view(-1),  # fp32 grad
             exp_avg.view(-1),
             exp_avg_sq.view(-1),
             loss_scale,
@@ -132,15 +131,14 @@ class Test():
             beta2,
             eps,
             weight_decay,
             True,
-            p_data_copy, # fp32 data
-            p_grad_copy, # fp32 grad
+            p_data_copy,  # fp32 data
+            p_grad_copy,  # fp32 grad
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
             use_adamw,
         )

-
         if loss_scale > 0:
             p_grad.div_(loss_scale)
@@ -158,16 +156,14 @@ class Test():
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
         self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
-        )
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")

     def test_cpu_adam(self):
         lr = 0.9
         eps = 1e-6
         weight_decay = 0
         for use_adamw in [False, True]:
-            for shape in [(1023, ), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                 for step in range(1, 2):
                     for lr in [0.01]:
                         for eps in [1e-8]:
@@ -175,7 +171,7 @@ class Test():
                                 for beta2 in [0.999]:
                                     for weight_decay in [0.001]:
                                         for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2 ** 5]:
+                                            for loss_scale in [-1, 2**5]:
                                                 self.check_res(
                                                     step,
                                                     lr,
@@ -191,7 +187,11 @@ class Test():
                                                 )


+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()
+
 if __name__ == "__main__":
     test = Test()
     test.test_cpu_adam()
     print('All is well.')
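Note on the fix itself: pytest matches test classes by the Test* name pattern, but it refuses to collect any class that defines an __init__ constructor, emitting a PytestCollectionWarning instead. The Test class above sets self.opt_id in __init__, so its test_cpu_adam method was never run under pytest. The added module-level test_cpu_adam() wrapper is a plain test_* function, which pytest always collects. A minimal sketch of the collection behavior, assuming default pytest settings (the file name is illustrative, not part of the commit):

# test_collection_sketch.py -- illustrative file, not from the commit
class Test:
    # Matches pytest's Test* class pattern, but the __init__
    # constructor makes pytest skip the whole class with a
    # PytestCollectionWarning instead of running its methods.
    def __init__(self):
        self.opt_id = 0

    def test_cpu_adam(self):
        assert self.opt_id == 0


def test_cpu_adam():
    # Plain module-level test_* function: collected unconditionally.
    # This is exactly the pattern the hotfix adds.
    test_case = Test()
    test_case.test_cpu_adam()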
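For readers checking the numerics: torch_adam_update computes the reference Adam/AdamW step that the cpu_adam extension is validated against. A minimal sketch of that update in the usual PyTorch formulation (a paraphrase for reference, assuming fp32 tensors; it is not the exact body of the test helper):

import torch

def adam_step(param, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2,
              eps, weight_decay, use_adamw):
    """One in-place Adam/AdamW update (standard formulation)."""
    if use_adamw:
        # AdamW: decoupled weight decay applied directly to the parameter.
        param.mul_(1 - lr * weight_decay)
    elif weight_decay != 0:
        # Classic Adam: L2 penalty folded into the gradient.
        grad = grad.add(param, alpha=weight_decay)

    # Exponential moving averages of the gradient and its square.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias-corrected step, matching bias_correction1/2 in the test.
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)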
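The loss_scale handling (grad.div_(loss_scale) when loss_scale > 0, with loss_scale swept over [-1, 2**5]) exercises static fp16 loss scaling: gradients arrive pre-multiplied by the scale to avoid fp16 underflow, the optimizer unscales them before the update, and -1 means scaling is disabled. A small round-trip illustration under those assumptions (variable names are hypothetical):

import torch

loss_scale = 2**5
true_grad = torch.randn(23)                    # what backward "should" produce
scaled_grad = (true_grad * loss_scale).half()  # fp16 grad as the optimizer sees it

grad = scaled_grad.float()
if loss_scale > 0:
    grad.div_(loss_scale)                      # unscale before the Adam update

assert torch.allclose(grad, true_grad, rtol=1e-3, atol=1e-3)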