[zero] hybrid cpu adam (#445)

Jiarui Fang 3 years ago committed by GitHub
parent b72b8445c6
commit 237d08e7ee

@@ -1,4 +1,5 @@
import torch
import math

class CPUAdam(torch.optim.Optimizer):
@@ -8,19 +9,18 @@ class CPUAdam(torch.optim.Optimizer):
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 loss_scale=-1,
                 simd_log=False):
        """
        An implementation equivalent to `torch.optim.Adam`.
        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
        The sharded params of model_params can reside on both CPU and CUDA.
        """
        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args)
        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
@@ -31,18 +31,45 @@ class CPUAdam(torch.optim.Optimizer):
        except ImportError:
            raise ImportError('Please install colossalai from source code to use CPUAdam')
        self.cpu_adam_op = cpu_adam
        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

    def __del__(self):
        self.cpu_adam_op.destroy_adam(self.opt_id)
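For orientation, a minimal construction sketch, not part of this diff: the import path is an assumption (the file name is not shown here), the C++ cpu_adam extension must be built from source, and a plain nn.Linear stands in for the sharded parameters of a ShardedModelV2.

from colossalai.nn.optimizer import CPUAdam   # assumed import path
import torch.nn as nn

model = nn.Linear(64, 64)                     # placeholder for a sharded model's params
optimizer = CPUAdam(model.parameters(),
                    lr=1e-3,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    weight_decay=0,
                    adamw_mode=True)          # AdamW-style decoupled weight decay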
    def torch_adam_update(self,
                          data,
                          grad,
                          exp_avg,
                          exp_avg_sq,
                          lr,
                          beta1,
                          beta2,
                          eps,
                          weight_decay,
                          bias_correction1,
                          bias_correction2,
                          loss_scale,
                          use_adamw=False):
        if loss_scale is not None:
            grad.div_(loss_scale)

        if weight_decay != 0:
            if use_adamw:
                # AdamW-style decoupled weight decay: shrink the weights directly.
                data.mul_(1 - lr * weight_decay)
            else:
                # Classic Adam / L2 regularization: fold weight decay into the gradient.
                grad = grad.add(data, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # TODO(jiaruifang) does not support amsgrad
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        data.addcdiv_(exp_avg, denom, value=-step_size)
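The pure-PyTorch update above mirrors a single torch.optim.Adam step. A rough self-check, not part of this commit: it reuses the hypothetical optimizer instance from the earlier sketch (the method touches no per-instance state) and assumes loss scaling is disabled by passing loss_scale=None.

import torch

torch.manual_seed(0)
p_ref = torch.randn(8, requires_grad=True)
p_manual = p_ref.detach().clone()
grad = torch.randn(8)

# Reference: one step of torch.optim.Adam.
ref_opt = torch.optim.Adam([p_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
p_ref.grad = grad.clone()
ref_opt.step()

# Same step via torch_adam_update with freshly initialized moments (step == 1).
exp_avg = torch.zeros(8)
exp_avg_sq = torch.zeros(8)
optimizer.torch_adam_update(p_manual, grad.clone(), exp_avg, exp_avg_sq,
                            lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                            weight_decay=0,
                            bias_correction1=1 - 0.9 ** 1,
                            bias_correction2=1 - 0.999 ** 1,
                            loss_scale=None, use_adamw=False)

assert torch.allclose(p_ref.detach(), p_manual, atol=1e-6)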
    @torch.no_grad()
    def step(self, closure=None):
@@ -51,47 +78,47 @@ class CPUAdam(torch.optim.Optimizer):
            with torch.enable_grad():
                loss = closure()

        for _, group in enumerate(self.param_groups):
            for _, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                state = self.state[p]

                target_device = p.device
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)

                state['step'] += 1
                beta1, beta2 = group['betas']

                if target_device.type == 'cpu':
                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                                 group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                                 state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
                elif target_device.type == 'cuda':
                    # FIXME() prepare grad on cuda
                    if p.grad.device.type == 'cpu':
                        p.grad = p.grad.to(target_device)

                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"

                    bias_correction1 = 1 - beta1**state['step']
                    bias_correction2 = 1 - beta2**state['step']

                    # adam on cuda
                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
                                           beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
                                           bias_correction2, self.loss_scale)
                else:
                    raise RuntimeError(f"CPUAdam does not support parameters on device '{target_device}'")
        return loss
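Taken together, step() now dispatches per parameter: CPU-resident parameters go through the fused C++ adam_update kernel, while CUDA-resident ones take the torch_adam_update path above. A hypothetical sketch of what this enables, assuming a GPU is available, the extension is built, and the default loss_scale setting leaves gradients unscaled:

import torch
import torch.nn as nn

layer_cpu = nn.Linear(32, 32)            # parameters stay on CPU -> fused C++ kernel
layer_gpu = nn.Linear(32, 32).cuda()     # parameters live on CUDA -> torch_adam_update

optimizer = CPUAdam(list(layer_cpu.parameters()) + list(layer_gpu.parameters()), lr=1e-3)

x = torch.randn(4, 32)
loss = layer_gpu(layer_cpu(x).cuda()).sum()
loss.backward()
optimizer.step()                         # each parameter is updated on its own device
optimizer.zero_grad()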

@@ -1,21 +1,20 @@
from typing import Callable

from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from colossalai.zero.shard_utils import TensorShardStrategy
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.core import global_context as gpc
from torch.optim import Optimizer
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.logging import get_dist_logger
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
    """

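The docstring of convert_to_zero_v2 is cut off in this view. Purely as an illustration of the signature shown above, a hypothetical call might look as follows; the contents of model_config and optimizer_config are assumptions and may not match the real schema.

import torch.nn as nn

def build_model():
    # Placeholder model builder; any callable returning an nn.Module.
    return nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))

# model_config / optimizer_config contents are assumed for illustration only.
zero_model, zero_optimizer = convert_to_zero_v2(model_builder=build_model,
                                                model_config={},
                                                optimizer_config={})
# zero_model: ShardedModelV2, zero_optimizer: ShardedOptimizerV2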