|
|
|
import math
|
|
|
|
import os
|
|
|
|
import tempfile
|
|
|
|
from typing import Callable, Dict, List, Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
|
|
|
|
|
|
class NVMeOptimizer(torch.optim.Optimizer):
    """A base class for offloading optimizer states.

    The class itself performs no parameter update: :meth:`step` raises ``NotImplementedError``.
    It provides the hooks (``_pre_step``, ``_post_state_init``, ``_pre_update``, ``_post_update``,
    ``_post_step``) that a concrete optimizer calls around its update, as shown in the example in
    :meth:`step`. When ``nvme_offload_fraction`` is ``0.0`` no offloader is created and every hook
    that touches the offloader returns early.

    Args:
        params: parameters
        defaults (dict): default dict
        nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
        offload_dir (Optional[str], optional): Directory to save NVMe offload files.
            If it's ``None``, a random temporary directory will be used. Defaults to None.

    Raises:
        ImportError: Raise if ``tensornvme`` is not installed (the concrete type raised is
            ``ModuleNotFoundError``, a subclass of ``ImportError``).
    """

    def __init__(self,
                 params,
                 defaults: dict,
                 nvme_offload_fraction: float = 0.0,
                 offload_dir: Optional[str] = None) -> None:
        assert 0.0 <= nvme_offload_fraction <= 1.0, 'nvme_offload_fraction must be in [0.0, 1.0]'
        super().__init__(params, defaults)
        self.nvme_offload_fraction = float(nvme_offload_fraction)
        if self.nvme_offload_fraction > 0.0:
            # tensornvme is imported lazily so that it is only required when offloading is requested.
            try:
                from tensornvme import DiskOffloader
                from tensornvme._C import get_backends
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer') from err
            self.offload_dir = offload_dir or tempfile.mkdtemp()
            # Prefer the 'uring' backend when tensornvme reports it, otherwise fall back to 'aio'.
            backend = 'uring' if 'uring' in get_backends() else 'aio'
            # NOTE(review): the positional ``8`` is presumably the offloader's queue size -- confirm
            # against the tensornvme ``DiskOffloader`` signature.
            self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend)
        else:
            self.offload_dir = None
            self.offloader = None
        # Per-parameter flag recorded by ``_post_state_init``: True when the state is offloaded.
        self.is_on_nvme: Dict[Parameter, bool] = {}
        # Running sum of the sizes of all parameters flagged as offloaded.
        self.offloaded_numel: int = 0
        # As param may be not materialized here, these attributes are initialized when the first step
        self.total_numel: Optional[int] = None
        self.can_offload_numel: Optional[int] = None

        # Offloaded parameters taking part in the current step, in update order, plus the reverse
        # mapping from parameter to its position in that list.
        self.prefetch_params: List[Parameter] = []
        self.param_to_prefetch_idx: Dict[Parameter, int] = {}

    def _get_numel(self) -> int:
        """Return the summed storage size of every parameter in every param group."""
        return sum(p.storage().size() for group in self.param_groups for p in group['params'])

    def _init_offload_budget(self) -> None:
        """Compute ``total_numel`` and ``can_offload_numel`` once, on first use."""
        if self.total_numel is None:
            self.total_numel = self._get_numel()
            self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)

    def _post_state_init(self, param: Parameter) -> None:
        """Decide whether the freshly created state of ``param`` is offloaded and record it.

        The state is flagged as offloaded only when an offloader exists, the parameter lives on
        the CPU, and adding its storage size keeps ``offloaded_numel`` within ``can_offload_numel``.

        Args:
            param (Parameter): The parameter whose optimizer state was just initialized.
        """
        numel = param.storage().size()
        offload = False
        if self.offloader is not None and param.device.type == 'cpu':
            # The budget is normally set by ``_pre_step``; compute it here as well so that calling
            # this hook first does not compare against ``None``.
            self._init_offload_budget()
            offload = numel + self.offloaded_numel <= self.can_offload_numel
        self.is_on_nvme[param] = offload
        if offload:
            self.offloaded_numel += numel

    def _setup_prefetch_params(self) -> None:
        """Fill ``prefetch_params`` / ``param_to_prefetch_idx`` for the current step.

        A parameter is listed when it has a gradient, already has optimizer state, and that state
        is flagged as offloaded. Does nothing when there is no offloader.
        """
        if self.offloader is None:
            return
        # Both containers are cleared by ``_post_step``; they must be empty when a step begins.
        assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if len(self.state[p]) > 0 and self.is_on_nvme[p]:
                    assert p.device.type == 'cpu'
                    self.param_to_prefetch_idx[p] = len(self.prefetch_params)
                    self.prefetch_params.append(p)

    def _pre_step(self, *state_keys: str) -> None:
        """Prepare a step: set the offload budget, build the prefetch list, start the first read.

        Args:
            *state_keys (str): Keys of the state tensors to read for the first prefetch parameter.
        """
        self._init_offload_budget()
        self._setup_prefetch_params()
        if self.offloader is None or len(self.prefetch_params) == 0:
            return
        # Issue the asynchronous read for the first parameter; later ones are chained in _pre_update.
        state = self.state[self.prefetch_params[0]]
        for key in state_keys:
            self.offloader.async_read(state[key])

    def _pre_update(self, param: Parameter, *state_keys: str) -> None:
        """Wait for pending reads, then start reading the next prefetch parameter's state.

        Does nothing when there is no offloader or ``param`` is not in the prefetch list.

        Args:
            param (Parameter): The parameter about to be updated.
            *state_keys (str): Keys of the state tensors to read for the following parameter.
        """
        if self.offloader is None or param not in self.param_to_prefetch_idx:
            return
        self.offloader.sync_read_events()
        next_idx = self.param_to_prefetch_idx[param] + 1
        if next_idx < len(self.prefetch_params):
            state = self.state[self.prefetch_params[next_idx]]
            for key in state_keys:
                self.offloader.async_read(state[key])

    def _post_update(self, param: Parameter, *state_keys: str) -> None:
        """Wait for pending writes, then write back the state of ``param`` if it is offloaded.

        Args:
            param (Parameter): The parameter that was just updated.
            *state_keys (str): Keys of the state tensors to write.
        """
        if self.offloader is None:
            return
        # The write sync happens for every parameter, offloaded or not.
        self.offloader.sync_write_events()
        if self.is_on_nvme[param]:
            state = self.state[param]
            for key in state_keys:
                self.offloader.async_write(state[key])

    def _post_step(self) -> None:
        """Synchronize the offloader and reset the per-step prefetch bookkeeping."""
        if self.offloader is not None:
            self.offloader.synchronize()
            self.prefetch_params.clear()
            self.param_to_prefetch_idx.clear()

    def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
        """Performs a single optimization step (parameter update).

        This base implementation always raises ``NotImplementedError``; subclasses override it
        and call the hooks in the order shown below.

        Example:

            >>> self._pre_step('exp_avg', 'exp_avg_sq')
            >>> for group in self.param_groups:
            >>>     for p in group['params']:
            >>>         if p.grad is None:
            >>>             continue
            >>>         state = self.state[p]
            >>>         if len(state) == 0:
            >>>             state['exp_avg'] = ...
            >>>             state['exp_avg_sq'] = ...
            >>>             self._post_state_init(p)
            >>>         if p.device.type == 'cpu':
            >>>             self._pre_update(p, 'exp_avg', 'exp_avg_sq')
            >>>             adam()
            >>>             self._post_update(p, 'exp_avg', 'exp_avg_sq')
            >>>         else:
            >>>             ...
            >>> self._post_step()

        Args:
            closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def state_dict(self) -> dict:
        """Return the optimizer state dict; unsupported while an offloader is in use.

        Raises:
            NotImplementedError: If NVMe offloading is enabled.
        """
        # TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM.
        if self.offloader is not None:
            raise NotImplementedError
        return super().state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """Load an optimizer state dict; unsupported while an offloader is in use.

        Args:
            state_dict (dict): State to load, forwarded to the parent implementation.

        Raises:
            NotImplementedError: If NVMe offloading is enabled.
        """
        # TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory.
        if self.offloader is not None:
            raise NotImplementedError
        super().load_state_dict(state_dict)

    def __del__(self) -> None:
        """Drop the offloader and make a best-effort attempt to remove the offload directory."""
        # ``getattr`` guards against a partially constructed instance (``__init__`` raised early).
        if getattr(self, 'offloader', None) is not None:
            del self.offloader
            offload_dir = getattr(self, 'offload_dir', None)
            if offload_dir is not None:
                # ``os.rmdir`` only removes an empty directory; a missing or non-empty directory
                # raises ``OSError``, which is deliberately ignored here.
                try:
                    os.rmdir(offload_dir)
                except OSError:
                    pass
|