diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 55f50bfdc..0b078dc88 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -1,4 +1,5 @@
 import torch
+import math
 
 
 class CPUAdam(torch.optim.Optimizer):
@@ -8,19 +9,18 @@ class CPUAdam(torch.optim.Optimizer):
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
-                betas=(0.9,
-                       0.999),
+                betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 adamw_mode=True,
                 loss_scale=-1,
                 simd_log=False):
-
-        default_args = dict(lr=lr,
-                            betas=betas,
-                            eps=eps,
-                            weight_decay=weight_decay,
-                            bias_correction=bias_correction)
+        """
+        An implementation equivalent to `torch.optim.Adam`.
+        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
+        The sharded params of model_params can reside on both CPU and CUDA.
+        """
+        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
         self.opt_id = CPUAdam.optimizer_id
         CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
@@ -31,18 +31,45 @@ class CPUAdam(torch.optim.Optimizer):
         except ImportError:
             raise ImportError('Please install colossalai from source code to use CPUAdam')
         self.cpu_adam_op = cpu_adam
-        self.cpu_adam_op.create_adam(self.opt_id,
-                                     lr,
-                                     betas[0],
-                                     betas[1],
-                                     eps,
-                                     weight_decay,
-                                     adamw_mode,
-                                     simd_log)
+        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
 
     def __del__(self):
         self.cpu_adam_op.destroy_adam(self.opt_id)
 
+    def torch_adam_update(self,
+                          data,
+                          grad,
+                          exp_avg,
+                          exp_avg_sq,
+                          lr,
+                          beta1,
+                          beta2,
+                          eps,
+                          weight_decay,
+                          bias_correction1,
+                          bias_correction2,
+                          loss_scale,
+                          use_adamw=False):
+        if loss_scale is not None:
+            grad.div_(loss_scale)
+
+        if weight_decay != 0:
+            if use_adamw:
+                data.mul_(1 - lr * weight_decay)
+            else:
+                grad = grad.add(data, alpha=weight_decay)
+
+        # Decay the first and second moment running average coefficient
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+        # TODO(jiaruifang) does not support amsgrad
+        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+        step_size = lr / bias_correction1
+
+        data.addcdiv_(exp_avg, denom, value=-step_size)
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
@@ -51,47 +78,47 @@
             with torch.enable_grad():
                 loss = closure()
 
-        # intended device for step
-        device = torch.device('cpu')
-
-        for group_id, group in enumerate(self.param_groups):
-            for param_id, p in enumerate(group['params']):
+        for _, group in enumerate(self.param_groups):
+            for _, p in enumerate(group['params']):
                 if p.grad is None:
                     continue
-                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
-                    "sure the cpu_offload is Ture"
-
                 state = self.state[p]
-                # State initialization
+
+                target_device = p.device
                 if len(state) == 0:
                     state['step'] = 0
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p.data,
-                                                        dtype=torch.float,
-                                                        device=device)
+                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p.data,
-                                                           dtype=torch.float,
-                                                           device=device)
-                    # memory_format=torch.preserve_format)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
 
                 state['step'] += 1
                 beta1, beta2 = group['betas']
 
-                self.cpu_adam_op.adam_update(self.opt_id,
-                                             state['step'],
-                                             group['lr'],
-                                             beta1,
-                                             beta2,
-                                             group['eps'],
-                                             group['weight_decay'],
-                                             group['bias_correction'],
-                                             p.data,
-                                             p.grad.data,
-                                             state['exp_avg'],
-                                             state['exp_avg_sq'],
-                                             self.loss_scale)
+                if target_device.type == 'cpu':
+                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
+                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
+                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
+                                                 group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
+                                                 state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
+                elif target_device.type == 'cuda':
+                    # FIXME() prepare grad on cuda
+                    if p.grad.device.type == 'cpu':
+                        p.grad = p.grad.to(target_device)
+
+                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
+                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
+
+                    bias_correction1 = 1 - beta1**state['step']
+                    bias_correction2 = 1 - beta2**state['step']
+
+                    # adam on cuda
+                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
+                                           beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
+                                           bias_correction2, self.loss_scale)
+                else:
+                    raise RuntimeError(f"CPUAdam does not support params on device type {target_device.type}")
 
         return loss
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 1d02c09bb..dc5814b66 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -1,21 +1,20 @@
-from asyncio.log import logger
-from distutils.command.config import config
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
 from colossalai.zero.shard_utils import TensorShardStrategy
-import torch
-import torch.nn as nn
 from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from torch.optim import Optimizer
-from .sharded_model import ShardedModel
-from .sharded_optim import ShardedOptimizer
 from colossalai.zero.init_ctx import ZeroInitContext
-from typing import Callable, Type
-from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from .sharded_model import ShardedModel
+from .sharded_optim import ShardedOptimizer
+
 
 def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
     """