[zero] hybrid cpu adam (#445)

pull/448/head
Jiarui Fang 3 years ago committed by GitHub
parent b72b8445c6
commit 237d08e7ee

@@ -1,4 +1,5 @@
import torch
import math


class CPUAdam(torch.optim.Optimizer):
@@ -8,19 +9,18 @@ class CPUAdam(torch.optim.Optimizer):
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 loss_scale=-1,
                 simd_log=False):
        """
        An implementation equivalent to `torch.optim.Adam`.
        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
        The sharded params of model_params can reside on both CPU and CUDA.
        """

        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args)
        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
@@ -31,18 +31,45 @@ class CPUAdam(torch.optim.Optimizer):
        except ImportError:
            raise ImportError('Please install colossalai from source code to use CPUAdam')

        self.cpu_adam_op = cpu_adam
        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

    def __del__(self):
        self.cpu_adam_op.destroy_adam(self.opt_id)

    def torch_adam_update(self,
                          data,
                          grad,
                          exp_avg,
                          exp_avg_sq,
                          lr,
                          beta1,
                          beta2,
                          eps,
                          weight_decay,
                          bias_correction1,
                          bias_correction2,
                          loss_scale,
                          use_adamw=False):
        if loss_scale is not None:
            grad.div_(loss_scale)

        if weight_decay != 0:
            if use_adamw:
                # decoupled weight decay (AdamW): shrink the weights directly
                data.mul_(1 - lr * weight_decay)
            else:
                # classic Adam: fold weight decay into the gradient
                grad = grad.add(data, alpha=weight_decay)

        # Decay the first and second moment running average coefficients
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # TODO(jiaruifang) does not support amsgrad
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        data.addcdiv_(exp_avg, denom, value=-step_size)

    @torch.no_grad()
    def step(self, closure=None):
@@ -51,47 +78,47 @@ class CPUAdam(torch.optim.Optimizer):
            with torch.enable_grad():
                loss = closure()

        for _, group in enumerate(self.param_groups):
            for _, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                state = self.state[p]

                target_device = p.device
                if len(state) == 0:
                    state['step'] = 0

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)

                state['step'] += 1
                beta1, beta2 = group['betas']

                if target_device.type == 'cpu':
                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                                 group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                                 state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
                elif target_device.type == 'cuda':
                    # FIXME() prepare grad on cuda
                    if p.grad.device.type == 'cpu':
                        p.grad = p.grad.to(target_device)

                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"

                    bias_correction1 = 1 - beta1 ** state['step']
                    bias_correction2 = 1 - beta2 ** state['step']

                    # adam update on cuda with plain torch ops
                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
                                           beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
                                           bias_correction2, self.loss_scale)
                else:
                    raise RuntimeError(f"CPUAdam does not support parameters on device {target_device}")
        return loss
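
The CUDA branch of step() falls back to the plain-PyTorch update in torch_adam_update. As a sanity check, the following standalone sketch (not part of the patch) replays that update for a single step with weight_decay=0 and no loss scaling and compares the result against torch.optim.Adam:

# Minimal standalone sanity check (not part of the patch): reproduce the update
# performed by torch_adam_update for one step (use_adamw=False, weight_decay=0,
# loss_scale=None) and compare it against torch.optim.Adam.
import math
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

p_ref = torch.randn(8, requires_grad=True)
p_manual = p_ref.detach().clone()
grad = torch.randn(8)

# Reference: a single step of torch.optim.Adam.
opt = torch.optim.Adam([p_ref], lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=0)
p_ref.grad = grad.clone()
opt.step()

# Manual update, following torch_adam_update line by line.
step = 1
exp_avg = torch.zeros_like(p_manual)
exp_avg_sq = torch.zeros_like(p_manual)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
p_manual.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))

assert torch.allclose(p_ref.detach(), p_manual, atol=1e-6), "manual Adam step diverged from torch.optim.Adam"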

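For context, the optimizer is constructed like any other torch.optim.Optimizer; the split between the fused CPU kernel and the torch fallback is decided per parameter from p.device inside step(). Below is a hypothetical usage sketch, not part of the patch: the import path and the plain nn.Linear stand-in are assumptions, since in the ZeRO setting the parameters would be the sharded params of a ShardedModelV2.

# Hypothetical usage sketch. It assumes colossalai has been built from source so
# the fused cpu_adam extension can be loaded; the import path below is an
# assumption made for illustration only.
import torch
import torch.nn as nn
from colossalai.nn.optimizer import CPUAdam  # assumed import path

model = nn.Linear(32, 32)  # parameters live on CPU -> fused C++ kernel path
optimizer = CPUAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)

loss = model(torch.randn(4, 32)).sum()
loss.backward()
optimizer.step()  # CPU params go through cpu_adam_op; CUDA params would take the torch_adam_update path
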
@@ -1,21 +1,20 @@
from typing import Callable, Type

import torch
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2

from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
"""
