[zero]added hybrid adam, removed loss scale in adam (#527)

* [zero]added hybrid adam, removed loss scale of adam

* remove useless code
pull/528/head^2
LuGY 2022-03-25 18:03:54 +08:00 committed by GitHub
parent 8d8c5407c0
commit 105c5301c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 149 additions and 27 deletions

View File

@ -17,7 +17,6 @@ class CPUAdam(torch.optim.Optimizer):
eps=1e-8,
weight_decay=0,
adamw_mode=True,
loss_scale=-1,
simd_log=False):
"""
An implementation equivalent to `torch.optim.Adam`.
@ -29,8 +28,7 @@ class CPUAdam(torch.optim.Optimizer):
super(CPUAdam, self).__init__(model_params, default_args)
self.opt_id = CPUAdam.optimizer_id
CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
self.adam_w_mode = adamw_mode
self.loss_scale = loss_scale
self.adamw_mode = adamw_mode
try:
import cpu_adam
except ImportError:
@ -54,12 +52,9 @@ class CPUAdam(torch.optim.Optimizer):
weight_decay,
bias_correction1,
bias_correction2,
loss_scale,
use_adamw=False):
# FIXME(ver217): remove the below line when replace torch adam with fused adam
grad = grad.float()
if loss_scale is not None:
grad.div_(loss_scale)
if weight_decay != 0:
if use_adamw:
@ -110,7 +105,7 @@ class CPUAdam(torch.optim.Optimizer):
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
state['exp_avg'], state['exp_avg_sq'], -1)
elif target_device.type == 'cuda':
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
@ -121,7 +116,7 @@ class CPUAdam(torch.optim.Optimizer):
# adam on cuda
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.loss_scale)
bias_correction2, self.adamw_mode)
else:
raise RuntimeError
return loss

View File

@ -0,0 +1,101 @@
import torch
from colossalai.utils import multi_tensor_applier
class HybridAdam(torch.optim.Optimizer):
optimizer_id = 0
# Number of fp32 shards for per parameter
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4
def __init__(self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
simd_log=False):
"""
An implementation equivalent to `torch.optim.Adam`.
The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
The sharded param of model_params can resident on both CPU and CUDA(fused adam).
"""
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args)
self.opt_id = HybridAdam.optimizer_id
HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
self.adamw_mode = adamw_mode
try:
import cpu_adam
import colossal_C
except ImportError:
raise ImportError('Please install colossalai from source code to use HybridAdam')
self.cpu_adam_op = cpu_adam
self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
self.gpu_adam_op = colossal_C.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
def __del__(self):
if self.cpu_adam_op:
self.cpu_adam_op.destroy_adam(self.opt_id)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for _, group in enumerate(self.param_groups):
g_l, p_l, m_l, v_l = [], [], [], []
group_step = 0
for _, p in enumerate(group['params']):
if p.grad is None:
continue
state = self.state[p]
target_device = p.device
if len(state) == 0:
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
state['step'] += 1
group_step = state['step']
beta1, beta2 = group['betas']
if target_device.type == 'cpu':
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], -1)
elif target_device.type == 'cuda':
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
# record the state by gruop and update at once
g_l.append(p.grad.data)
p_l.append(p.data)
m_l.append(state['exp_avg'])
v_l.append(state['exp_avg_sq'])
else:
raise RuntimeError
if len(g_l) > 0:
adamw_mode = 1 if self.adamw_mode else 0
bias_correction = 1 if group['bias_correction'] else 0
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l,m_l, v_l],
group['lr'], group['betas'][0], group['betas'][1], group['eps'], group_step,
adamw_mode, bias_correction, group['weight_decay'])
return loss

View File

@ -15,11 +15,8 @@ def torch_adam_update(
grad,
exp_avg,
exp_avg_sq,
loss_scale,
use_adamw,
):
if loss_scale > 0:
grad.div_(loss_scale)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
@ -50,10 +47,9 @@ def assertTrue(condition, msg):
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('loss_scale', [-1, 2 ** 5])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
def test_cpu_adam(adamw, step, p_dtype, g_dtype):
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
@ -63,8 +59,6 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
p_data = torch.rand(64, dtype=p_dtype)
p_data_copy = p_data.clone().float()
p_grad = torch.rand(64, dtype=g_dtype)
if loss_scale > 0:
p_grad.mul_(loss_scale)
p_grad_copy = p_grad.clone().float()
exp_avg = torch.rand(p_data.shape)
exp_avg_copy = exp_avg.clone()
@ -75,7 +69,7 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
import cpu_adam
cpu_adam_op = cpu_adam
except:
raise ImportError("...")
raise ImportError("Import cpu adam error, please install colossal from source code")
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
cpu_adam_op.adam_update(
@ -91,7 +85,7 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
p_grad.view(-1), # fp32 grad
exp_avg.view(-1),
exp_avg_sq.view(-1),
loss_scale,
-1,
)
torch_adam_update(
@ -105,20 +99,15 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
p_grad_copy, # fp32 grad
exp_avg_copy,
exp_avg_sq_copy,
loss_scale,
adamw,
)
if loss_scale > 0:
p_grad.div_(loss_scale)
var = p_data_copy - p_data
data_diff = torch.max(torch.abs(var))
threshold = 1e-3
print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
assertLess(
data_diff,
threshold,
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
)
max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))

View File

@ -18,11 +18,8 @@ def torch_adam_update(
grad,
exp_avg,
exp_avg_sq,
loss_scale,
use_adamw,
):
if loss_scale > 0:
grad.div_(loss_scale)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
@ -87,7 +84,6 @@ def test_adam(adamw, step, p_dtype, g_dtype):
g_copy, # fp32 grad
m_copy,
v_copy,
-1,
adamw,
)

View File

@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW
from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import parameterize
RE = 1024
@parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, device, p_dtype, g_dtype):
rng_state = torch.get_rng_state()
p = nn.Parameter(torch.rand(64).to(device, p_dtype))
torch.set_rng_state(rng_state)
p_copy = nn.Parameter(torch.rand(64).to(device).float())
if adamw:
optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
torch_optim = AdamW([p_copy], lr=1e-3)
else:
optim = HybridAdam([p], lr=1e-3)
torch_optim = Adam([p_copy], lr=1e-3)
print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
for i in range(RE):
p.grad = torch.rand(64).to(device, p_dtype)
p_copy.grad = p.grad.clone().float()
p.grad.data = p.grad.data.to(g_dtype)
optim.step()
torch_optim.step()
if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
continue
assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"