mirror of https://github.com/hpcaitech/ColossalAI
[zero]added hybrid adam, removed loss scale in adam (#527)
* [zero]added hybrid adam, removed loss scale of adam * remove useless codepull/528/head^2
parent
8d8c5407c0
commit
105c5301c3
|
@ -17,7 +17,6 @@ class CPUAdam(torch.optim.Optimizer):
|
||||||
eps=1e-8,
|
eps=1e-8,
|
||||||
weight_decay=0,
|
weight_decay=0,
|
||||||
adamw_mode=True,
|
adamw_mode=True,
|
||||||
loss_scale=-1,
|
|
||||||
simd_log=False):
|
simd_log=False):
|
||||||
"""
|
"""
|
||||||
An implementation equivalent to `torch.optim.Adam`.
|
An implementation equivalent to `torch.optim.Adam`.
|
||||||
|
@ -29,8 +28,7 @@ class CPUAdam(torch.optim.Optimizer):
|
||||||
super(CPUAdam, self).__init__(model_params, default_args)
|
super(CPUAdam, self).__init__(model_params, default_args)
|
||||||
self.opt_id = CPUAdam.optimizer_id
|
self.opt_id = CPUAdam.optimizer_id
|
||||||
CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
|
CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
|
||||||
self.adam_w_mode = adamw_mode
|
self.adamw_mode = adamw_mode
|
||||||
self.loss_scale = loss_scale
|
|
||||||
try:
|
try:
|
||||||
import cpu_adam
|
import cpu_adam
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -54,12 +52,9 @@ class CPUAdam(torch.optim.Optimizer):
|
||||||
weight_decay,
|
weight_decay,
|
||||||
bias_correction1,
|
bias_correction1,
|
||||||
bias_correction2,
|
bias_correction2,
|
||||||
loss_scale,
|
|
||||||
use_adamw=False):
|
use_adamw=False):
|
||||||
# FIXME(ver217): remove the below line when replace torch adam with fused adam
|
# FIXME(ver217): remove the below line when replace torch adam with fused adam
|
||||||
grad = grad.float()
|
grad = grad.float()
|
||||||
if loss_scale is not None:
|
|
||||||
grad.div_(loss_scale)
|
|
||||||
|
|
||||||
if weight_decay != 0:
|
if weight_decay != 0:
|
||||||
if use_adamw:
|
if use_adamw:
|
||||||
|
@ -110,7 +105,7 @@ class CPUAdam(torch.optim.Optimizer):
|
||||||
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||||
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
|
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||||
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||||
state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
|
state['exp_avg'], state['exp_avg_sq'], -1)
|
||||||
elif target_device.type == 'cuda':
|
elif target_device.type == 'cuda':
|
||||||
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
|
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||||
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
|
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||||
|
@ -121,7 +116,7 @@ class CPUAdam(torch.optim.Optimizer):
|
||||||
# adam on cuda
|
# adam on cuda
|
||||||
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
|
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
|
||||||
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
|
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
|
||||||
bias_correction2, self.loss_scale)
|
bias_correction2, self.adamw_mode)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
import torch
|
||||||
|
from colossalai.utils import multi_tensor_applier
|
||||||
|
|
||||||
|
class HybridAdam(torch.optim.Optimizer):
|
||||||
|
optimizer_id = 0
|
||||||
|
# Number of fp32 shards for per parameter
|
||||||
|
# Param weight, grad, momentum and variance
|
||||||
|
num_fp32_shards_per_param = 4
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_params,
|
||||||
|
lr=1e-3,
|
||||||
|
bias_correction=True,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
eps=1e-8,
|
||||||
|
weight_decay=0,
|
||||||
|
adamw_mode=True,
|
||||||
|
simd_log=False):
|
||||||
|
"""
|
||||||
|
An implementation equivalent to `torch.optim.Adam`.
|
||||||
|
The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
|
||||||
|
The sharded param of model_params can resident on both CPU and CUDA(fused adam).
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||||
|
super(HybridAdam, self).__init__(model_params, default_args)
|
||||||
|
self.opt_id = HybridAdam.optimizer_id
|
||||||
|
HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
|
||||||
|
self.adamw_mode = adamw_mode
|
||||||
|
try:
|
||||||
|
import cpu_adam
|
||||||
|
import colossal_C
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Please install colossalai from source code to use HybridAdam')
|
||||||
|
|
||||||
|
self.cpu_adam_op = cpu_adam
|
||||||
|
self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
|
||||||
|
|
||||||
|
self.gpu_adam_op = colossal_C.multi_tensor_adam
|
||||||
|
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.cpu_adam_op:
|
||||||
|
self.cpu_adam_op.destroy_adam(self.opt_id)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for _, group in enumerate(self.param_groups):
|
||||||
|
g_l, p_l, m_l, v_l = [], [], [], []
|
||||||
|
group_step = 0
|
||||||
|
for _, p in enumerate(group['params']):
|
||||||
|
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
target_device = p.device
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
|
||||||
|
# gradient momentums
|
||||||
|
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
|
||||||
|
# gradient variances
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
group_step = state['step']
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
|
if target_device.type == 'cpu':
|
||||||
|
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||||
|
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||||
|
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||||
|
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||||
|
state['exp_avg'], state['exp_avg_sq'], -1)
|
||||||
|
|
||||||
|
elif target_device.type == 'cuda':
|
||||||
|
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||||
|
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||||
|
|
||||||
|
# record the state by gruop and update at once
|
||||||
|
g_l.append(p.grad.data)
|
||||||
|
p_l.append(p.data)
|
||||||
|
m_l.append(state['exp_avg'])
|
||||||
|
v_l.append(state['exp_avg_sq'])
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
if len(g_l) > 0:
|
||||||
|
adamw_mode = 1 if self.adamw_mode else 0
|
||||||
|
bias_correction = 1 if group['bias_correction'] else 0
|
||||||
|
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l,m_l, v_l],
|
||||||
|
group['lr'], group['betas'][0], group['betas'][1], group['eps'], group_step,
|
||||||
|
adamw_mode, bias_correction, group['weight_decay'])
|
||||||
|
return loss
|
|
@ -15,11 +15,8 @@ def torch_adam_update(
|
||||||
grad,
|
grad,
|
||||||
exp_avg,
|
exp_avg,
|
||||||
exp_avg_sq,
|
exp_avg_sq,
|
||||||
loss_scale,
|
|
||||||
use_adamw,
|
use_adamw,
|
||||||
):
|
):
|
||||||
if loss_scale > 0:
|
|
||||||
grad.div_(loss_scale)
|
|
||||||
bias_correction1 = 1 - beta1**step
|
bias_correction1 = 1 - beta1**step
|
||||||
bias_correction2 = 1 - beta2**step
|
bias_correction2 = 1 - beta2**step
|
||||||
|
|
||||||
|
@ -50,10 +47,9 @@ def assertTrue(condition, msg):
|
||||||
|
|
||||||
@parameterize('adamw', [True, False])
|
@parameterize('adamw', [True, False])
|
||||||
@parameterize('step', [1, 2])
|
@parameterize('step', [1, 2])
|
||||||
@parameterize('loss_scale', [-1, 2 ** 5])
|
|
||||||
@parameterize('p_dtype', [torch.float, torch.half])
|
@parameterize('p_dtype', [torch.float, torch.half])
|
||||||
@parameterize('g_dtype', [torch.float, torch.half])
|
@parameterize('g_dtype', [torch.float, torch.half])
|
||||||
def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
|
def test_cpu_adam(adamw, step, p_dtype, g_dtype):
|
||||||
lr = 1e-3
|
lr = 1e-3
|
||||||
beta1, beta2 = 0.9, 0.999
|
beta1, beta2 = 0.9, 0.999
|
||||||
eps = 1e-8
|
eps = 1e-8
|
||||||
|
@ -63,8 +59,6 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
|
||||||
p_data = torch.rand(64, dtype=p_dtype)
|
p_data = torch.rand(64, dtype=p_dtype)
|
||||||
p_data_copy = p_data.clone().float()
|
p_data_copy = p_data.clone().float()
|
||||||
p_grad = torch.rand(64, dtype=g_dtype)
|
p_grad = torch.rand(64, dtype=g_dtype)
|
||||||
if loss_scale > 0:
|
|
||||||
p_grad.mul_(loss_scale)
|
|
||||||
p_grad_copy = p_grad.clone().float()
|
p_grad_copy = p_grad.clone().float()
|
||||||
exp_avg = torch.rand(p_data.shape)
|
exp_avg = torch.rand(p_data.shape)
|
||||||
exp_avg_copy = exp_avg.clone()
|
exp_avg_copy = exp_avg.clone()
|
||||||
|
@ -75,7 +69,7 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
|
||||||
import cpu_adam
|
import cpu_adam
|
||||||
cpu_adam_op = cpu_adam
|
cpu_adam_op = cpu_adam
|
||||||
except:
|
except:
|
||||||
raise ImportError("...")
|
raise ImportError("Import cpu adam error, please install colossal from source code")
|
||||||
|
|
||||||
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
|
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
|
||||||
cpu_adam_op.adam_update(
|
cpu_adam_op.adam_update(
|
||||||
|
@ -91,7 +85,7 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
|
||||||
p_grad.view(-1), # fp32 grad
|
p_grad.view(-1), # fp32 grad
|
||||||
exp_avg.view(-1),
|
exp_avg.view(-1),
|
||||||
exp_avg_sq.view(-1),
|
exp_avg_sq.view(-1),
|
||||||
loss_scale,
|
-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_adam_update(
|
torch_adam_update(
|
||||||
|
@ -105,20 +99,15 @@ def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
|
||||||
p_grad_copy, # fp32 grad
|
p_grad_copy, # fp32 grad
|
||||||
exp_avg_copy,
|
exp_avg_copy,
|
||||||
exp_avg_sq_copy,
|
exp_avg_sq_copy,
|
||||||
loss_scale,
|
|
||||||
adamw,
|
adamw,
|
||||||
)
|
)
|
||||||
if loss_scale > 0:
|
|
||||||
p_grad.div_(loss_scale)
|
|
||||||
var = p_data_copy - p_data
|
var = p_data_copy - p_data
|
||||||
data_diff = torch.max(torch.abs(var))
|
data_diff = torch.max(torch.abs(var))
|
||||||
threshold = 1e-3
|
threshold = 1e-3
|
||||||
print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
|
|
||||||
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
|
|
||||||
assertLess(
|
assertLess(
|
||||||
data_diff,
|
data_diff,
|
||||||
threshold,
|
threshold,
|
||||||
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
|
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
|
||||||
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
|
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
|
||||||
)
|
)
|
||||||
max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
|
max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
|
||||||
|
|
|
@ -18,11 +18,8 @@ def torch_adam_update(
|
||||||
grad,
|
grad,
|
||||||
exp_avg,
|
exp_avg,
|
||||||
exp_avg_sq,
|
exp_avg_sq,
|
||||||
loss_scale,
|
|
||||||
use_adamw,
|
use_adamw,
|
||||||
):
|
):
|
||||||
if loss_scale > 0:
|
|
||||||
grad.div_(loss_scale)
|
|
||||||
bias_correction1 = 1 - beta1**step
|
bias_correction1 = 1 - beta1**step
|
||||||
bias_correction2 = 1 - beta2**step
|
bias_correction2 = 1 - beta2**step
|
||||||
|
|
||||||
|
@ -87,7 +84,6 @@ def test_adam(adamw, step, p_dtype, g_dtype):
|
||||||
g_copy, # fp32 grad
|
g_copy, # fp32 grad
|
||||||
m_copy,
|
m_copy,
|
||||||
v_copy,
|
v_copy,
|
||||||
-1,
|
|
||||||
adamw,
|
adamw,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim.adam import Adam
|
||||||
|
from torch.optim import AdamW
|
||||||
|
|
||||||
|
from colossalai.nn.optimizer.hybrid_adam import HybridAdam
|
||||||
|
from colossalai.testing import parameterize
|
||||||
|
|
||||||
|
RE = 1024
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('adamw', [False, True])
|
||||||
|
@parameterize('device', ['cpu', 'cuda:0'])
|
||||||
|
@parameterize('p_dtype', [torch.float])
|
||||||
|
@parameterize('g_dtype', [torch.float, torch.half])
|
||||||
|
def test_adam(adamw, device, p_dtype, g_dtype):
|
||||||
|
rng_state = torch.get_rng_state()
|
||||||
|
p = nn.Parameter(torch.rand(64).to(device, p_dtype))
|
||||||
|
torch.set_rng_state(rng_state)
|
||||||
|
p_copy = nn.Parameter(torch.rand(64).to(device).float())
|
||||||
|
|
||||||
|
if adamw:
|
||||||
|
optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
|
||||||
|
torch_optim = AdamW([p_copy], lr=1e-3)
|
||||||
|
else:
|
||||||
|
optim = HybridAdam([p], lr=1e-3)
|
||||||
|
torch_optim = Adam([p_copy], lr=1e-3)
|
||||||
|
|
||||||
|
print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
|
||||||
|
for i in range(RE):
|
||||||
|
p.grad = torch.rand(64).to(device, p_dtype)
|
||||||
|
p_copy.grad = p.grad.clone().float()
|
||||||
|
p.grad.data = p.grad.data.to(g_dtype)
|
||||||
|
|
||||||
|
optim.step()
|
||||||
|
torch_optim.step()
|
||||||
|
|
||||||
|
if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
|
||||||
|
continue
|
||||||
|
assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
|
||||||
|
f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"
|
Loading…
Reference in New Issue