[hotfix] run cpu adam unittest in pytest (#424)

Jiarui Fang 2022-03-16 10:39:55 +08:00 committed by GitHub
parent 54229cd33e
commit 5d7dc3525b
1 changed file with 19 additions and 19 deletions

@@ -29,12 +29,12 @@
 import math
 import torch
 import colossalai

 try:
     import cpu_adam
 except ImportError:
     raise ImportError("import cpu_adam error")

 def torch_adam_update(
     step,
     lr,
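Note: cpu_adam is the compiled extension that provides ColossalAI's CPU Adam kernel, so the try/except above turns a missing native build into an explicit error. A slightly friendlier guard (a sketch only, not part of this commit) would chain the original exception so the root cause stays visible:

    try:
        import cpu_adam  # compiled CPU Adam extension exercised by this test
    except ImportError as e:
        # re-raise with context so a missing native build is easy to diagnose
        raise ImportError("import cpu_adam error") from e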
@@ -42,7 +42,6 @@ def torch_adam_update(
     beta2,
     eps,
     weight_decay,
-    bias_correction,
     param,
     grad,
     exp_avg,
@@ -52,8 +51,8 @@ def torch_adam_update(
 ):
     if loss_scale > 0:
         grad.div_(loss_scale)

-    bias_correction1 = 1 - beta1 ** step
-    bias_correction2 = 1 - beta2 ** step
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
     if weight_decay != 0:
         if use_adamw:
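Note: the hunks above remove the unused bias_correction parameter from torch_adam_update's signature; the corrections themselves are always computed from step. For reference, a minimal sketch of the bias-corrected Adam step that this reference implementation checks the extension against (variable names follow this file; the weight-decay and AdamW branches are elided):

    import math
    import torch

    def adam_step(param, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps):
        # biased first and second moment estimates, updated in place
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # undo the bias toward zero from the zero-initialized moments
        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)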
@@ -73,12 +72,13 @@ def torch_adam_update(
 class Test():

     def __init__(self):
         self.opt_id = 0

     def assertLess(self, data_diff, threshold, msg):
         assert data_diff < threshold, msg

     def assertTrue(self, condition, msg):
         assert condition, msg
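Note: since Test is a plain class rather than a unittest.TestCase subclass, it re-implements the two assertion helpers the test body uses; the bare assert statements also benefit from pytest's assertion introspection when the file runs under pytest. A trivial usage sketch:

    t = Test()
    t.assertLess(1e-4, 1e-3, "diff too large")    # passes silently
    t.assertTrue(1e-4 < 1e-3, "diff too large")   # equivalent check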
@@ -89,7 +89,6 @@ class Test():
         eps,
         beta1,
         beta2,
         weight_decay,
         shape,
         grad_dtype,
@@ -118,8 +117,8 @@ class Test():
             eps,
             weight_decay,
             True,
-            p_data.view(-1), # fp32 data
-            p_grad.view(-1), # fp32 grad
+            p_data.view(-1),  # fp32 data
+            p_grad.view(-1),  # fp32 grad
             exp_avg.view(-1),
             exp_avg_sq.view(-1),
             loss_scale,
@@ -132,15 +131,14 @@ class Test():
             beta2,
             eps,
             weight_decay,
-            True,
-            p_data_copy, # fp32 data
-            p_grad_copy, # fp32 grad
+            p_data_copy,  # fp32 data
+            p_grad_copy,  # fp32 grad
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
             use_adamw,
         )

         if loss_scale > 0:
             p_grad.div_(loss_scale)
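Note on the loss-scale convention exercised here (inferred from the code above): a positive loss_scale means the gradients arrive pre-multiplied by that factor and are divided back out before the update, while -1 disables scaling entirely. A tiny self-contained sketch:

    import torch

    loss_scale = 2**5
    true_grad = torch.randn(8, 24)
    grad = true_grad * loss_scale    # what a scaled backward pass would produce
    if loss_scale > 0:
        grad.div_(loss_scale)        # recover the true gradient, as above
    assert torch.allclose(grad, true_grad)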
@@ -158,16 +156,14 @@ class Test():
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
         self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
-        )
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")

     def test_cpu_adam(self):
         lr = 0.9
         eps = 1e-6
         weight_decay = 0
         for use_adamw in [False, True]:
-            for shape in [(1023, ), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                 for step in range(1, 2):
                     for lr in [0.01]:
                         for eps in [1e-8]:
@@ -175,7 +171,7 @@ class Test():
                                 for beta2 in [0.999]:
                                     for weight_decay in [0.001]:
                                         for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2 ** 5]:
+                                            for loss_scale in [-1, 2**5]:
                                                 self.check_res(
                                                     step,
                                                     lr,
@@ -191,7 +187,11 @@ class Test():
                                                 )

+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()
+
 if __name__ == "__main__":
     test = Test()
     test.test_cpu_adam()
     print('All is well.')
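Note: the added module-level test_cpu_adam() wrapper is what actually makes this file run under pytest: by default pytest only collects Test*-named classes that define no __init__, and this Test has one, so nothing in the file was collected before this commit. With the wrapper in place the test runs either way (the file path below is assumed, not taken from this page):

    pytest tests/test_optimizer/test_cpu_adam.py    # collected via the new test_cpu_adam()
    python tests/test_optimizer/test_cpu_adam.py    # runs the __main__ block instead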