You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/nn/optimizer/hybrid_adam.py

102 lines
4.2 KiB

import torch
from colossalai.utils import multi_tensor_applier
class HybridAdam(torch.optim.Optimizer):
    """Adam-style optimizer whose parameters may live on CPU or on CUDA.

    During :meth:`step`, parameters found on the CPU are updated one by one
    through the ``cpu_adam`` extension, while parameters found on CUDA are
    collected per param group and updated with a single
    ``multi_tensor_applier`` call using ``colossal_C.multi_tensor_adam``.
    """

    # Class-wide counter; each instance takes the current value as its
    # ``opt_id`` (the handle passed to the ``cpu_adam`` extension).
    optimizer_id = 0
    # Number of fp32 shards kept per parameter:
    # param weight, grad, momentum and variance.
    num_fp32_shards_per_param = 4

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 simd_log=False):
        """
        An implementation equivalent to `torch.optim.Adam`.
        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
        The sharded params of model_params can reside on both CPU and CUDA (fused adam).

        Args:
            model_params: parameters (or param groups) handed to
                ``torch.optim.Optimizer``.
            lr: learning rate stored in the param-group defaults and passed to
                ``cpu_adam.create_adam``.
            bias_correction: stored in the param-group defaults and forwarded
                to both update kernels on every step.
            betas: pair ``(beta1, beta2)`` stored in the param-group defaults.
            eps: stored in the param-group defaults.
            weight_decay: stored in the param-group defaults.
            adamw_mode: forwarded to ``cpu_adam.create_adam`` and, as 0/1, to
                the CUDA kernel.
            simd_log: forwarded to ``cpu_adam.create_adam``.
                NOTE(review): presumably toggles logging in the CPU extension
                -- confirm against the extension source.

        Raises:
            ImportError: if the ``cpu_adam`` or ``colossal_C`` extensions
                cannot be imported.
        """
        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(HybridAdam, self).__init__(model_params, default_args)
        self.opt_id = HybridAdam.optimizer_id
        HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
        self.adamw_mode = adamw_mode
        try:
            import cpu_adam
            import colossal_C
        except ImportError as err:
            # Chain the original error so the real cause (which module is
            # missing) stays visible in the traceback.
            raise ImportError('Please install colossalai from source code to use HybridAdam') from err

        # Create the native optimizer state first and only then publish the
        # handle on ``self``: ``__del__`` uses the presence of
        # ``cpu_adam_op`` as the signal that ``destroy_adam`` must be called.
        cpu_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
        self.cpu_adam_op = cpu_adam

        self.gpu_adam_op = colossal_C.multi_tensor_adam
        # Buffer handed to ``multi_tensor_applier`` on every CUDA update.
        # It is allocated on the CUDA device, so construction needs CUDA.
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

    def __del__(self):
        """Release the native CPU Adam state, if it was ever created."""
        # ``__del__`` also runs when ``__init__`` raised part-way through, in
        # which case ``cpu_adam_op`` was never assigned.
        cpu_adam_op = getattr(self, 'cpu_adam_op', None)
        if cpu_adam_op is not None:
            cpu_adam_op.destroy_adam(self.opt_id)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: if a parameter with a gradient lives on a device
                that is neither ``cpu`` nor ``cuda``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            g_l, p_l, m_l, v_l = [], [], [], []
            # Step count forwarded to the CUDA kernel: the step of the last
            # parameter in this group that had a gradient.
            group_step = 0
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                target_device = p.device
                if len(state) == 0:
                    # Lazy state initialisation on the parameter's device.
                    state['step'] = 0
                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)

                state['step'] += 1
                group_step = state['step']

                if target_device.type == 'cpu':
                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                                 group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                                 state['exp_avg'], state['exp_avg_sq'], -1)
                elif target_device.type == 'cuda':
                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"

                    # record the state by group and update at once
                    g_l.append(p.grad.data)
                    p_l.append(p.data)
                    m_l.append(state['exp_avg'])
                    v_l.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError(
                        f"HybridAdam only supports parameters on 'cpu' or 'cuda', got '{target_device.type}'")

            if len(g_l) > 0:
                # The CUDA kernel takes the boolean switches as integers.
                adamw_mode = 1 if self.adamw_mode else 0
                bias_correction = 1 if group['bias_correction'] else 0
                multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l],
                                     group['lr'], beta1, beta2, group['eps'], group_step,
                                     adamw_mode, bias_correction, group['weight_decay'])
        return loss