mirror of https://github.com/hpcaitech/ColossalAI
[nvme] CPUAdam and HybridAdam support NVMe offload (#1360)
* impl nvme optimizer * update cpu adam * add unit test * update hybrid adam * update docstr * add TODOs * update CI * fix CI * fix CI * fix CI path * fix CI path * fix CI path * fix install tensornvme * fix CI * fix CI path * fix CI env variables * test CI * test CI * fix CI * fix nvme optim __del__ * fix adam __del__ * fix nvme optim * fix CI env variables * fix nvme optim import * test CI * test CI * fix CIpull/1378/head
parent
8463290642
commit
c415240db6
|
@ -18,6 +18,17 @@ jobs:
|
|||
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
|
||||
timeout-minutes: 40
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
repository: hpcaitech/TensorNVMe
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
path: TensorNVMe
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
cd TensorNVMe
|
||||
conda install cmake
|
||||
pip install -r requirements.txt
|
||||
pip install -v .
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
|
@ -35,3 +46,4 @@ jobs:
|
|||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
NCCL_SHM_DISABLE: 1
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
|
|
|
@ -16,6 +16,17 @@ jobs:
|
|||
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
|
||||
timeout-minutes: 40
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
repository: hpcaitech/TensorNVMe
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
path: TensorNVMe
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
cd TensorNVMe
|
||||
conda install cmake
|
||||
pip install -r requirements.txt
|
||||
pip install -v .
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
|
@ -33,4 +44,5 @@ jobs:
|
|||
[ "$gpu_used" -le "100" ] && PYTHONPATH=$PWD pytest tests
|
||||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
|
|
@ -3,10 +3,12 @@ import torch
|
|||
|
||||
from colossalai.registry import OPTIMIZERS
|
||||
from colossalai.nn.optimizer import CPU_ADAM_CNT
|
||||
from .nvme_optimizer import NVMeOptimizer
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@OPTIMIZERS.register_module
|
||||
class CPUAdam(torch.optim.Optimizer):
|
||||
class CPUAdam(NVMeOptimizer):
|
||||
"""Implements Adam algorithm.
|
||||
|
||||
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
|
||||
|
@ -45,6 +47,9 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
True for decoupled weight decay(also known as AdamW) (default: True)
|
||||
simd_log (boolean, optional): whether to show if you are using SIMD to
|
||||
accelerate. (default: False)
|
||||
nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
|
||||
offload_dir (Optional[str], optional): Directory to save NVMe offload files.
|
||||
If it's ``None``, a random temporary directory will be used. Defaults to None.
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
|
@ -64,10 +69,12 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
adamw_mode=True,
|
||||
simd_log=False):
|
||||
simd_log=False,
|
||||
nvme_offload_fraction: float = 0.0,
|
||||
nvme_offload_dir: Optional[str] = None):
|
||||
|
||||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||
super(CPUAdam, self).__init__(model_params, default_args)
|
||||
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.opt_id = CPU_ADAM_CNT()
|
||||
self.adamw_mode = adamw_mode
|
||||
try:
|
||||
|
@ -78,7 +85,8 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
|
||||
|
||||
def __del__(self):
|
||||
if self.cpu_adam_op:
|
||||
super().__del__()
|
||||
if getattr(self, 'cpu_adam_op', None):
|
||||
self.cpu_adam_op.destroy_adam(self.opt_id)
|
||||
|
||||
def torch_adam_update(self,
|
||||
|
@ -121,6 +129,7 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
self._pre_step('exp_avg', 'exp_avg_sq')
|
||||
for _, group in enumerate(self.param_groups):
|
||||
for _, p in enumerate(group['params']):
|
||||
|
||||
|
@ -137,6 +146,7 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
|
||||
# gradient variances
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
|
||||
self._post_state_init(p)
|
||||
|
||||
state['step'] += 1
|
||||
beta1, beta2 = group['betas']
|
||||
|
@ -145,9 +155,11 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
|
||||
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||
state['exp_avg'], state['exp_avg_sq'], -1)
|
||||
self._post_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
elif target_device.type == 'cuda':
|
||||
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||
|
@ -161,4 +173,5 @@ class CPUAdam(torch.optim.Optimizer):
|
|||
bias_correction2, self.adamw_mode)
|
||||
else:
|
||||
raise RuntimeError
|
||||
self._post_step()
|
||||
return loss
|
||||
|
|
|
@ -3,10 +3,12 @@ import torch
|
|||
from colossalai.utils import multi_tensor_applier
|
||||
from colossalai.registry import OPTIMIZERS
|
||||
from colossalai.nn.optimizer import CPU_ADAM_CNT
|
||||
from typing import Optional
|
||||
from .nvme_optimizer import NVMeOptimizer
|
||||
|
||||
|
||||
@OPTIMIZERS.register_module
|
||||
class HybridAdam(torch.optim.Optimizer):
|
||||
class HybridAdam(NVMeOptimizer):
|
||||
"""Implements Adam algorithm.
|
||||
|
||||
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
|
||||
|
@ -44,6 +46,9 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
True for decoupled weight decay(also known as AdamW) (default: True)
|
||||
simd_log (boolean, optional): whether to show if you are using SIMD to
|
||||
accelerate. (default: False)
|
||||
nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
|
||||
offload_dir (Optional[str], optional): Directory to save NVMe offload files.
|
||||
If it's ``None``, a random temporary directory will be used. Defaults to None.
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
|
@ -63,10 +68,12 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
adamw_mode=True,
|
||||
simd_log=False):
|
||||
simd_log=False,
|
||||
nvme_offload_fraction: float = 0.0,
|
||||
nvme_offload_dir: Optional[str] = None):
|
||||
|
||||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||
super(HybridAdam, self).__init__(model_params, default_args)
|
||||
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.opt_id = CPU_ADAM_CNT()
|
||||
self.adamw_mode = adamw_mode
|
||||
try:
|
||||
|
@ -82,7 +89,8 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
def __del__(self):
|
||||
if self.cpu_adam_op:
|
||||
super().__del__()
|
||||
if getattr(self, 'cpu_adam_op', None):
|
||||
self.cpu_adam_op.destroy_adam(self.opt_id)
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -92,6 +100,7 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
self._pre_step('exp_avg', 'exp_avg_sq')
|
||||
for _, group in enumerate(self.param_groups):
|
||||
g_l, p_l, m_l, v_l = [], [], [], []
|
||||
group_step = 0
|
||||
|
@ -110,6 +119,7 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
|
||||
# gradient variances
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
|
||||
self._post_state_init(p)
|
||||
|
||||
state['step'] += 1
|
||||
group_step = state['step']
|
||||
|
@ -118,9 +128,11 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
if target_device.type == 'cpu':
|
||||
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||
state['exp_avg'], state['exp_avg_sq'], -1)
|
||||
self._post_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
|
||||
elif target_device.type == 'cuda':
|
||||
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
|
||||
|
@ -140,4 +152,5 @@ class HybridAdam(torch.optim.Optimizer):
|
|||
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
|
||||
group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
|
||||
bias_correction, group['weight_decay'])
|
||||
self._post_step()
|
||||
return loss
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
import torch
|
||||
import os
|
||||
import tempfile
|
||||
import math
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional, List, Dict, Callable
|
||||
|
||||
|
||||
class NVMeOptimizer(torch.optim.Optimizer):
|
||||
"""A base class for offloading optimizer states.
|
||||
|
||||
Args:
|
||||
params: parameters
|
||||
defaults (dict): default dict
|
||||
nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
|
||||
offload_dir (Optional[str], optional): Directory to save NVMe offload files.
|
||||
If it's ``None``, a random temporary directory will be used. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ImportError: Raise if ``tensornvme`` is not installed.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
defaults: dict,
|
||||
nvme_offload_fraction: float = 0.0,
|
||||
offload_dir: Optional[str] = None) -> None:
|
||||
assert 0.0 <= nvme_offload_fraction <= 1.0
|
||||
super().__init__(params, defaults)
|
||||
self.nvme_offload_fraction = float(nvme_offload_fraction)
|
||||
if self.nvme_offload_fraction > 0.0:
|
||||
try:
|
||||
from tensornvme import DiskOffloader
|
||||
from tensornvme._C import get_backends
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer')
|
||||
self.offload_dir = offload_dir or tempfile.mkdtemp()
|
||||
backend = 'uring' if 'uring' in get_backends() else 'aio'
|
||||
self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend)
|
||||
else:
|
||||
self.offload_dir = None
|
||||
self.offloader = None
|
||||
self.is_on_nvme: Dict[Parameter, bool] = {}
|
||||
self.offloaded_numel: int = 0
|
||||
self.total_numel: int = self._get_numel()
|
||||
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
|
||||
|
||||
self.prefetch_params: List[Parameter] = []
|
||||
self.param_to_prefetch_idx: Dict[Parameter, int] = {}
|
||||
|
||||
def _get_numel(self) -> int:
|
||||
numel = 0
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
numel += p.storage().size()
|
||||
return numel
|
||||
|
||||
def _post_state_init(self, param: Parameter) -> None:
|
||||
numel = param.storage().size()
|
||||
if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel:
|
||||
self.is_on_nvme[param] = True
|
||||
self.offloaded_numel += numel
|
||||
else:
|
||||
self.is_on_nvme[param] = False
|
||||
|
||||
def _setup_prefetch_params(self) -> List[Parameter]:
|
||||
if self.offloader is None:
|
||||
return
|
||||
assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if len(self.state[p]) > 0 and self.is_on_nvme[p]:
|
||||
assert p.device.type == 'cpu'
|
||||
self.param_to_prefetch_idx[p] = len(self.prefetch_params)
|
||||
self.prefetch_params.append(p)
|
||||
|
||||
def _pre_step(self, *state_keys: str) -> None:
|
||||
self._setup_prefetch_params()
|
||||
if self.offloader is None or len(self.prefetch_params) == 0:
|
||||
return
|
||||
state = self.state[self.prefetch_params[0]]
|
||||
for key in state_keys:
|
||||
self.offloader.async_read(state[key])
|
||||
|
||||
def _pre_update(self, param: Parameter, *state_keys: str) -> None:
|
||||
if self.offloader is None or param not in self.param_to_prefetch_idx:
|
||||
return
|
||||
self.offloader.sync_read_events()
|
||||
idx = self.param_to_prefetch_idx[param]
|
||||
if idx + 1 < len(self.prefetch_params):
|
||||
state = self.state[self.prefetch_params[idx + 1]]
|
||||
for key in state_keys:
|
||||
self.offloader.async_read(state[key])
|
||||
|
||||
def _post_update(self, param: Parameter, *state_keys: str) -> None:
|
||||
if self.offloader is None:
|
||||
return
|
||||
self.offloader.sync_write_events()
|
||||
if self.is_on_nvme[param]:
|
||||
state = self.state[param]
|
||||
for key in state_keys:
|
||||
self.offloader.async_write(state[key])
|
||||
|
||||
def _post_step(self) -> None:
|
||||
if self.offloader is not None:
|
||||
self.offloader.synchronize()
|
||||
self.prefetch_params.clear()
|
||||
self.param_to_prefetch_idx.clear()
|
||||
|
||||
def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
|
||||
"""Performs a single optimization step (parameter update).
|
||||
|
||||
Example:
|
||||
|
||||
>>> self._pre_step('exp_avg', 'exp_avg_sq')
|
||||
>>> for group in self.param_groups:
|
||||
>>> for p in group['params']:
|
||||
>>> if p.grad is None:
|
||||
>>> continue
|
||||
>>> state = self.state[p]
|
||||
>>> if len(state) == 0:
|
||||
>>> state['exp_avg'] = ...
|
||||
>>> state['exp_avg_sq'] = ...
|
||||
>>> self._post_state_init(p)
|
||||
>>> if p.device.type == 'cpu':
|
||||
>>> self._pre_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
>>> adam()
|
||||
>>> self._post_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
>>> else:
|
||||
>>> ...
|
||||
>>> self._post_step()
|
||||
|
||||
Args:
|
||||
closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and
|
||||
returns the loss. Optional for most optimizers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
# TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM.
|
||||
if self.offloader is not None:
|
||||
raise NotImplementedError
|
||||
return super().state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
# TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory.
|
||||
if self.offloader is not None:
|
||||
raise NotImplementedError
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, 'offloader', None) is not None:
|
||||
del self.offloader
|
||||
if os.path.exists(self.offload_dir):
|
||||
try:
|
||||
os.rmdir(self.offload_dir)
|
||||
except OSError:
|
||||
pass
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
import torch
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.optimizer import CPUAdam, HybridAdam
|
||||
|
||||
|
||||
def move_some_params_to_cuda(model, torch_model):
|
||||
model.embed.weight.data = model.embed.weight.cuda()
|
||||
torch_model.embed.weight.data = model.embed.weight.cuda()
|
||||
model.ln1.weight.data = model.ln1.weight.cuda()
|
||||
torch_model.ln1.weight.data = model.ln1.weight.cuda()
|
||||
|
||||
|
||||
def check_params_equal(model, torch_model):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('nvme_offload_fraction', [0.0, 0.5, 1.0])
|
||||
@pytest.mark.parametrize('nvme_offload_dir', ['./offload', None])
|
||||
@pytest.mark.parametrize('adam_cls', [CPUAdam, HybridAdam])
|
||||
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
|
||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
model = model_builder()
|
||||
torch_model = model_builder()
|
||||
move_some_params_to_cuda(model, torch_model)
|
||||
optimizer = adam_cls(model.parameters(),
|
||||
lr=0.1,
|
||||
nvme_offload_fraction=nvme_offload_fraction,
|
||||
nvme_offload_dir=nvme_offload_dir)
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1)
|
||||
with torch.no_grad():
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
torch_p.copy_(p)
|
||||
p.grad = torch.rand_like(p)
|
||||
torch_p.grad = p.grad
|
||||
|
||||
for _ in range(3):
|
||||
optimizer.step()
|
||||
torch_optimizer.step()
|
||||
check_params_equal(model, torch_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_nvme_adam(0.5, './offload', CPUAdam)
|
Loading…
Reference in New Issue