import torch

from colossalai.utils import multi_tensor_applier


class HybridAdam(torch.optim.Optimizer):

    optimizer_id = 0
    # Number of fp32 shards per parameter:
    # param weight, grad, momentum and variance
    num_fp32_shards_per_param = 4

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 simd_log=False):
        """
        An implementation equivalent to `torch.optim.Adam`.
        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
        The sharded params of model_params can reside on both CPU and CUDA (fused Adam).
        """

        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(HybridAdam, self).__init__(model_params, default_args)
        self.opt_id = HybridAdam.optimizer_id
        HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
        self.adamw_mode = adamw_mode
        try:
            import cpu_adam
            import colossal_C
        except ImportError:
            raise ImportError('Please install colossalai from source code to use HybridAdam')

        self.cpu_adam_op = cpu_adam
        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

        self.gpu_adam_op = colossal_C.multi_tensor_adam
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
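        # Two fused backends are now wired up: self.cpu_adam_op updates parameters that live
        # on the CPU one tensor at a time, while self.gpu_adam_op (colossal_C.multi_tensor_adam)
        # updates all CUDA parameters of a param group in a single multi-tensor kernel launch,
        # driven by multi_tensor_applier in step().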

    def __del__(self):
        if self.cpu_adam_op:
            self.cpu_adam_op.destroy_adam(self.opt_id)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for _, group in enumerate(self.param_groups):
            g_l, p_l, m_l, v_l = [], [], [], []
            group_step = 0
            for _, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                state = self.state[p]

                target_device = p.device
                if len(state) == 0:
                    state['step'] = 0

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)

                state['step'] += 1
                group_step = state['step']
                beta1, beta2 = group['betas']
                if target_device.type == 'cpu':
                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                                 group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                                 state['exp_avg'], state['exp_avg_sq'], -1)

                elif target_device.type == 'cuda':
                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"

                    # record the state by group and update all CUDA params at once below
                    g_l.append(p.grad.data)
                    p_l.append(p.data)
                    m_l.append(state['exp_avg'])
                    v_l.append(state['exp_avg_sq'])

                else:
                    raise RuntimeError(f"HybridAdam only supports CPU and CUDA parameters, got device {target_device}")
            if len(g_l) > 0:
                adamw_mode = 1 if self.adamw_mode else 0
                bias_correction = 1 if group['bias_correction'] else 0
                multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l],
                                     group['lr'], group['betas'][0], group['betas'][1], group['eps'], group_step,
                                     adamw_mode, bias_correction, group['weight_decay'])
        return loss
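

# ------------------------------------------------------------------------------------------
# Reference sketch: a plain-PyTorch version of the decoupled-weight-decay (AdamW-style)
# update that the fused cpu_adam / colossal_C.multi_tensor_adam kernels are expected to
# perform when adamw_mode=True. The function name and argument layout are illustrative
# assumptions for documentation, not the extensions' actual C++ API.
# ------------------------------------------------------------------------------------------
def _reference_adamw_update(param, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps,
                            weight_decay, bias_correction=True):
    # update biased first and second moment estimates in place
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1**step if bias_correction else 1.0
    bias_correction2 = 1 - beta2**step if bias_correction else 1.0

    # decoupled weight decay (adamw_mode=True); adamw_mode=False would instead fold
    # weight_decay * param into the gradient before the moment updates
    param.mul_(1 - lr * weight_decay)

    # param -= lr * m_hat / (sqrt(v_hat) + eps)
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)


# Usage sketch: HybridAdam follows the standard torch.optim.Optimizer interface. This assumes
# colossalai was built from source (so the cpu_adam and colossal_C extensions import) and a
# CUDA-capable environment; the toy model below is hypothetical.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(16, 16)    # toy model for illustration only
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    out = model(torch.randn(4, 16))
    out.sum().backward()
    optimizer.step()        # CPU params go through cpu_adam, CUDA params through the fused multi-tensor kernel
    optimizer.zero_grad()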