import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from .meta import ParamDistMeta


def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    """Call ``fn(arg)`` and return its result unless ``arg`` is ``None``.

    When ``arg`` is ``None`` the function is not called and ``None`` is
    returned.
    """
    if arg is not None:
        return fn(arg)
    return None


def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    """Map parameter names of ``model`` to integer indices in ``optimizer``.

    Indices are assigned by walking ``optimizer.param_groups`` in order and
    numbering the parameters consecutively; a parameter that shows up more
    than once keeps the first index it received. ``build_checkpoints`` uses
    these indices as keys into ``optimizer.state_dict()['state']``.

    Args:
        model: Module whose ``named_parameters()`` provide the names.
        optimizer: Optimizer whose ``param_groups`` provide the ordering.

    Returns:
        ``{parameter_name: optimizer_index}`` for every model parameter that
        is managed by the optimizer. Parameters that are not in any param
        group (e.g. frozen ones) are left out instead of raising ``KeyError``.

    Raises:
        AssertionError: If the optimizer holds a parameter that is not a
            parameter of ``model``.
    """
    # ensure all params in optimizer are in model state dict
    model_param_ids = {id(p) for p in model.parameters()}
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in model_param_ids, 'optimizer holds a parameter that does not belong to the model'

    param_mappings: Dict[int, int] = {}
    start_index = 0
    for group in optimizer.param_groups:
        for index, p in enumerate(group['params'], start_index):
            # First occurrence wins, so a parameter shared between groups
            # keeps a single index.
            if id(p) not in param_mappings:
                param_mappings[id(p)] = index
        start_index += len(group['params'])

    return {name: param_mappings[id(p)] for name, p in model.named_parameters() if id(p) in param_mappings}


def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    """Return the total size in bytes of all tensor values in ``state``.

    Non-tensor values do not contribute to the size.
    """
    return sum(v.numel() * v.element_size() for v in state.values() if isinstance(v, Tensor))


class ModelCheckpointSharder:
    """Accumulate model tensors into shards bounded by ``max_shard_size`` bytes.

    A shard is emitted when a new tensor is appended while the buffer has
    already reached ``max_shard_size``; the size check happens *before* the
    new tensor is added, so an emitted shard can exceed the limit by the
    tensor that made it cross the threshold. A ``max_shard_size`` that is
    not positive disables sharding.
    """

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        """Add ``tensor`` under ``key``.

        Returns:
            The previous buffer if it was full and had to be flushed before
            adding ``tensor``, otherwise ``None``.
        """
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {}
            self.buffer_size = 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return retval

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        """Append every entry of ``state_dict`` and return the shards flushed meanwhile."""
        shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            if shard is not None:
                shards.append(shard)
        return shards

    def complete(self) -> Optional[dict]:
        """Return the remaining buffer, or ``None`` if it is empty."""
        return self.buffer if len(self.buffer) > 0 else None


class OptimizerCheckpointSharder:
    """Accumulate per-parameter optimizer states into size-bounded shards.

    The initial buffer carries ``param_groups`` next to ``'state'``; buffers
    created after a flush only contain ``'state'``. Flushing follows the same
    rule as ``ModelCheckpointSharder``: the size check happens before the new
    state is added, and a non-positive ``max_shard_size`` disables sharding.
    """

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        # Not read by any code in this file; kept so the attribute stays
        # available to external users.
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        """Add the optimizer ``state`` of parameter index ``key``.

        Returns:
            The previous buffer if it was full and had to be flushed before
            adding ``state``, otherwise ``None``.
        """
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return retval

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        """Append every entry of ``state_dict['state']`` and return the shards flushed meanwhile."""
        shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            if shard is not None:
                shards.append(shard)
        return shards

    def complete(self) -> Optional[dict]:
        """Return the remaining buffer, or ``None`` if it holds no state."""
        return self.buffer if len(self.buffer['state']) > 0 else None


def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    """Split model and (optionally) optimizer state dicts into shards.

    Args:
        max_shard_size: Size threshold in bytes handed to the sharders; a
            non-positive value yields a single shard.
        model_state_dict: Tensors to shard, keyed by name.
        optimizer_state_dict: Optional dict with ``'state'`` and
            ``'param_groups'`` entries.
        param_to_os: Mapping from parameter name to optimizer state index;
            required when ``optimizer_state_dict`` is given.

    Returns:
        ``(model_shards, optimizer_shards)``; the second list is empty when
        no optimizer state dict is given.

    Raises:
        AssertionError: If ``param_to_os`` is missing, or an optimizer state
            index cannot be mapped to a key of ``model_state_dict``.
    """
    has_optimizer = optimizer_state_dict is not None
    if has_optimizer:
        assert param_to_os is not None, 'param_to_os is required when optimizer_state_dict is given'
        os_to_param = {v: k for k, v in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param, f'optimizer state {os_key!r} has no entry in param_to_os'
            assert os_to_param[os_key] in model_state_dict, \
                f'parameter {os_to_param[os_key]!r} is missing from model_state_dict'

    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not has_optimizer:
        return model_shards, []

    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards


def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    """Flag which optimizer state entries have the same shape as their parameter.

    Returns:
        ``{state_index: {state_key: bool}}`` where the flag is ``True`` when
        the state value is a tensor whose shape equals the shape of the
        corresponding entry of ``model_state_dict``, and ``False`` otherwise.
    """
    os_to_param = {v: k for k, v in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        p = model_state_dict[os_to_param[idx]]
        paired_os[idx] = {k: isinstance(v, Tensor) and v.shape == p.shape for k, v in state.items()}
    return paired_os


def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    """Build sharded model/optimizer checkpoints together with their metadata.

    Args:
        max_size: Shard size threshold in bytes (see ``shard_checkpoint``).
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optional optimizer whose ``state_dict()`` is saved.
        param_to_os: Mapping from parameter name to optimizer state index;
            computed with ``get_param_to_os`` when omitted or empty.
        dist_meta: Per-entry distributed metadata keyed like the model state
            dict. ``None`` means no filtering is done.
        eliminate_replica: When ``dist_meta`` is given, keep only entries
            whose metadata has ``used_zero`` set or ``dp_rank == 0``.

    Returns:
        ``(model_checkpoints, optimizer_checkpoints, meta)``. ``meta`` holds
        ``'dist_meta'`` and ``'params'`` and, when an optimizer is given,
        ``'param_to_os'`` and ``'paired_os'``. Both lists are empty (and a
        warning is issued) when nothing is left in the model state dict.
    """
    save_global = dist_meta is None
    has_optimizer = optimizer is not None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if has_optimizer else None
    meta = {'dist_meta': dist_meta}
    if has_optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp replicated params
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if has_optimizer:
            full_state = optimizer_state_dict['state']
            kept_state = {}
            # model_state_dict is already filtered above. Entries that have
            # no optimizer index (e.g. buffers or parameters unknown to the
            # optimizer) or no optimizer state yet are skipped rather than
            # raising KeyError, matching the unfiltered path which simply
            # saves whatever state exists.
            for name in model_state_dict:
                os_idx = param_to_os.get(name)
                if os_idx is not None and os_idx in full_state:
                    kept_state[os_idx] = full_state[os_idx]
            optimizer_state_dict['state'] = kept_state
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta


def is_duplicated_list(list_: List[Any]) -> bool:
    """Return ``True`` if every element of ``list_`` equals the first one.

    An empty list counts as duplicated.
    """
    if len(list_) == 0:
        return True
    elem = list_[0]
    return not any(x != elem for x in list_[1:])


def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    """Copy every entry of ``src_state`` into ``dest_state``.

    If ``dest_state`` already holds a tensor under a key, the new value is
    copied into that tensor in place so that existing references to it stay
    valid. In every other case (key missing, or existing value is not a
    tensor) the value is assigned, so non-tensor entries are refreshed and
    missing entries are created.
    """
    for k, v in src_state.items():
        old_v = dest_state.get(k)
        if isinstance(old_v, Tensor):
            old_v.copy_(v)
        else:
            dest_state[k] = v


def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    """Load per-parameter optimizer state from ``state_dict`` into ``optimizer``.

    ``state_dict`` is deep-copied first, so the optimizer never aliases the
    caller's objects. Saved parameter indices are matched to the optimizer's
    parameters by position within the param groups.

    Args:
        optimizer: Optimizer to update in place.
        state_dict: Dict with ``'state'`` and ``'param_groups'`` entries.
        strict: When ``True``, raise if ``state_dict['state']`` has indices
            unknown to the optimizer or lacks indices the optimizer has.

    Raises:
        AssertionError: If the ``param_groups`` of ``state_dict`` differ from
            those of ``optimizer.state_dict()``.
        RuntimeError: In strict mode, when keys are missing or unexpected.
    """
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'], \
        'param_groups of state_dict do not match those of the optimizer'
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    saved_ids = chain.from_iterable(g['params'] for g in saved_groups)
    live_params = chain.from_iterable(g['params'] for g in groups)
    idx_to_p: Dict[int, Parameter] = dict(zip(saved_ids, live_params))
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                 "\n\t".join(error_msgs)))