# ColossalAI/colossalai/nn/optimizer/nvme_optimizer.py
# (Header captured from the repository web viewer: 168 lines, 6.6 KiB, Python.)

import math
import os
import tempfile
from typing import Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
class NVMeOptimizer(torch.optim.Optimizer):
    """A base class for offloading optimizer states.

    Subclasses implement :meth:`step` and call the ``_pre_step`` /
    ``_post_state_init`` / ``_pre_update`` / ``_post_update`` / ``_post_step``
    hooks around their update so that the states of CPU parameters can be
    written to, and prefetched from, the offloader. When
    ``nvme_offload_fraction`` is ``0.0`` no offloader is created and the
    read/write hooks do nothing.

    Args:
        params: parameters
        defaults (dict): default dict
        nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
        offload_dir (Optional[str], optional): Directory to save NVMe offload files.
            If it's ``None``, a random temporary directory will be used. Defaults to None.

    Raises:
        ImportError: Raise if ``tensornvme`` is not installed.
    """

    def __init__(
        self, params, defaults: dict, nvme_offload_fraction: float = 0.0, offload_dir: Optional[str] = None
    ) -> None:
        assert 0.0 <= nvme_offload_fraction <= 1.0, "nvme_offload_fraction must be in [0.0, 1.0]"
        super().__init__(params, defaults)
        self.nvme_offload_fraction = float(nvme_offload_fraction)
        if self.nvme_offload_fraction > 0.0:
            # tensornvme is only required when offloading is actually requested.
            try:
                from tensornvme import DiskOffloader
                from tensornvme._C import get_backends
            except ModuleNotFoundError as err:
                # ModuleNotFoundError is a subclass of ImportError; keep the original cause attached.
                raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") from err
            self.offload_dir = offload_dir or tempfile.mkdtemp()
            # Prefer the "uring" backend when tensornvme reports it, otherwise use "aio".
            backend = "uring" if "uring" in get_backends() else "aio"
            # NOTE(review): the positional 8 is passed straight to DiskOffloader;
            # presumably its number of I/O entries - confirm against tensornvme.
            self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend)
        else:
            self.offload_dir = None
            self.offloader = None
        # Per-parameter flag: are this parameter's states kept on NVMe?
        self.is_on_nvme: Dict[Parameter, bool] = {}
        # Running total of storage elements already assigned to NVMe.
        self.offloaded_numel: int = 0
        # As param may be not materialized here, these attributes are initialized when the first step
        self.total_numel: Optional[int] = None
        self.can_offload_numel: Optional[int] = None
        # Prefetch order for the current step, and each parameter's position in it.
        self.prefetch_params: List[Parameter] = []
        self.param_to_prefetch_idx: Dict[Parameter, int] = {}

    def _get_numel(self) -> int:
        """Return the summed storage size of every parameter in all param groups."""
        numel = 0
        for group in self.param_groups:
            for p in group["params"]:
                numel += p.storage().size()
        return numel

    def _post_state_init(self, param: Parameter) -> None:
        """Decide whether the states of ``param`` are offloaded and record it.

        A parameter is marked as on NVMe only when an offloader exists, the
        parameter lives on CPU, and adding its storage size keeps
        ``offloaded_numel`` within ``can_offload_numel``. The budget is only
        compared when the first two conditions hold, and it is set by
        :meth:`_pre_step`.

        Args:
            param (Parameter): The parameter whose state was just initialized.
        """
        numel = param.storage().size()
        if (
            self.offloader is not None
            and param.device.type == "cpu"
            and numel + self.offloaded_numel <= self.can_offload_numel
        ):
            self.is_on_nvme[param] = True
            self.offloaded_numel += numel
        else:
            self.is_on_nvme[param] = False

    def _setup_prefetch_params(self) -> None:
        """Build the prefetch list for the current step.

        Collects, in param-group order, every parameter that has a gradient,
        a non-empty state and is flagged in ``is_on_nvme``. Does nothing when
        there is no offloader. The prefetch structures must be empty on entry.
        """
        if self.offloader is None:
            return
        assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if len(self.state[p]) > 0 and self.is_on_nvme[p]:
                    assert p.device.type == "cpu"
                    self.param_to_prefetch_idx[p] = len(self.prefetch_params)
                    self.prefetch_params.append(p)

    def _pre_step(self, *state_keys: str) -> None:
        """Prepare a step: set the offload budget once and start the first read.

        Args:
            *state_keys (str): Keys of the state tensors to read for the first
                parameter in the prefetch list.
        """
        if self.total_numel is None:
            # First step: parameters are expected to be materialized by now.
            self.total_numel = self._get_numel()
            self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
        self._setup_prefetch_params()
        if self.offloader is None or len(self.prefetch_params) == 0:
            return
        state = self.state[self.prefetch_params[0]]
        for key in state_keys:
            self.offloader.async_read(state[key])

    def _pre_update(self, param: Parameter, *state_keys: str) -> None:
        """Wait for pending reads, then start reading the next parameter's states.

        Args:
            param (Parameter): The parameter about to be updated.
            *state_keys (str): Keys of the state tensors to read for the next
                parameter in the prefetch list.
        """
        if self.offloader is None or param not in self.param_to_prefetch_idx:
            return
        self.offloader.sync_read_events()
        idx = self.param_to_prefetch_idx[param]
        if idx + 1 < len(self.prefetch_params):
            state = self.state[self.prefetch_params[idx + 1]]
            for key in state_keys:
                self.offloader.async_read(state[key])

    def _post_update(self, param: Parameter, *state_keys: str) -> None:
        """Wait for pending writes, then write back the states of ``param`` if offloaded.

        Args:
            param (Parameter): The parameter that was just updated.
            *state_keys (str): Keys of the state tensors to write.
        """
        if self.offloader is None:
            return
        self.offloader.sync_write_events()
        if self.is_on_nvme[param]:
            state = self.state[param]
            for key in state_keys:
                self.offloader.async_write(state[key])

    def _post_step(self) -> None:
        """Synchronize the offloader and reset the prefetch structures."""
        if self.offloader is not None:
            self.offloader.synchronize()
            # Without an offloader these are never populated, so clearing here is enough.
            self.prefetch_params.clear()
            self.param_to_prefetch_idx.clear()

    def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
        """Performs a single optimization step (parameter update).

        This base implementation always raises; subclasses override it and
        call the hooks as in the example below.

        Example:

            >>> self._pre_step('exp_avg', 'exp_avg_sq')
            >>> for group in self.param_groups:
            >>>     for p in group['params']:
            >>>         if p.grad is None:
            >>>             continue
            >>>         state = self.state[p]
            >>>         if len(state) == 0:
            >>>             state['exp_avg'] = ...
            >>>             state['exp_avg_sq'] = ...
            >>>             self._post_state_init(p)
            >>>         if p.device.type == 'cpu':
            >>>             self._pre_update(p, 'exp_avg', 'exp_avg_sq')
            >>>             adam()
            >>>             self._post_update(p, 'exp_avg', 'exp_avg_sq')
            >>>         else:
            >>>             ...
            >>> self._post_step()

        Args:
            closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def state_dict(self) -> dict:
        """Return the optimizer state dict.

        Raises:
            NotImplementedError: If an offloader is in use.
        """
        # TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM.
        if self.offloader is not None:
            raise NotImplementedError
        return super().state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """Load an optimizer state dict.

        Args:
            state_dict (dict): State to load, passed to the parent implementation.

        Raises:
            NotImplementedError: If an offloader is in use.
        """
        # TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory.
        if self.offloader is not None:
            raise NotImplementedError
        super().load_state_dict(state_dict)

    def __del__(self) -> None:
        """Drop the offloader and try to remove the offload directory."""
        # getattr guards against a partially constructed instance.
        if getattr(self, "offloader", None) is not None:
            del self.offloader
            if os.path.exists(self.offload_dir):
                # Best effort: os.rmdir only succeeds on an empty directory.
                try:
                    os.rmdir(self.offload_dir)
                except OSError:
                    pass