ColossalAI/colossalai/utils/checkpoint_io/utils.py

import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from .meta import ParamDistMeta


def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    if arg is not None:
        return fn(arg)


def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    """Map each named parameter of ``model`` to its integer index in ``optimizer``'s state dict."""
    # ensure all params in the optimizer are parameters of the model
    params_set = set(id(p) for p in model.parameters())
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in params_set
    param_mappings = {}
    start_index = 0

    def get_group_mapping(group):
        nonlocal start_index
        param_mappings.update(
            {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
        start_index += len(group['params'])

    for g in optimizer.param_groups:
        get_group_mapping(g)
    return {k: param_mappings[id(p)] for k, p in model.named_parameters()}
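
# Illustrative note (the parameter names are assumptions, not from this module): for a
# model whose named_parameters() yields ('weight', 'bias') in a single param group, the
# mapping above would be {'weight': 0, 'bias': 1}, matching the integer keys that
# torch.optim.Optimizer.state_dict()['state'] uses for the same parameters.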


def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    """Sum the byte size of all tensor values in a per-parameter optimizer state dict."""
    size = 0
    for v in state.values():
        if isinstance(v, Tensor):
            size += v.numel() * v.element_size()
    return size


class ModelCheckpointSharder:
    """Accumulates model tensors into shards of at most ``max_shard_size`` bytes."""

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        """Add a tensor to the buffer; return the previous buffer as a shard if the size limit was reached."""
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {}
            self.buffer_size = 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return retval

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        """Return the remaining buffer as the final shard, or ``None`` if it is empty."""
        return self.buffer if len(self.buffer) > 0 else None
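
# A minimal usage sketch for the model sharder (the `model` object and the 1 GiB limit
# are illustrative assumptions, not part of this module):
#
#     sharder = ModelCheckpointSharder(max_shard_size=1024 ** 3)
#     shards = sharder.extend(model.state_dict())
#     run_if_not_none(shards.append, sharder.complete())
#     # `shards` is now a list of state-dict fragments, each roughly at most 1 GiB.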


class OptimizerCheckpointSharder:
    """Accumulates per-parameter optimizer states into shards of at most ``max_shard_size`` bytes."""

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            # only the first shard carries 'param_groups'; later shards hold 'state' only
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return retval

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer['state']) > 0 else None


def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    """Split a model state dict (and optionally an optimizer state dict) into shards of at most ``max_shard_size`` bytes."""
    has_optimizer: bool = False
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        os_to_param = {v: k for k, v in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
        has_optimizer = True
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not has_optimizer:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards
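
# A minimal sketch of calling shard_checkpoint directly (the `model`/`optimizer`
# objects and the 512 MiB limit are assumptions for illustration):
#
#     param_to_os = get_param_to_os(model, optimizer)
#     model_shards, optim_shards = shard_checkpoint(512 * 1024 ** 2,
#                                                   model.state_dict(),
#                                                   optimizer.state_dict(),
#                                                   param_to_os)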


def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    """For each optimizer state entry, flag which values are tensors with the same shape as the paired parameter."""
    os_to_param = {v: k for k, v in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        paired_os[idx] = {}
        p = model_state_dict[os_to_param[idx]]
        for k, v in state.items():
            if isinstance(v, Tensor) and v.shape == p.shape:
                paired_os[idx][k] = True
            else:
                paired_os[idx][k] = False
    return paired_os
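
# Illustrative note (Adam is just an example, not assumed by this module): the returned
# mapping would typically look like {0: {'step': False, 'exp_avg': True, 'exp_avg_sq': True}, ...},
# since 'exp_avg'/'exp_avg_sq' match the parameter's shape while 'step' is a scalar.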


def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    """Build sharded model/optimizer checkpoints plus a metadata dict describing their layout."""
    save_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp-replicated params: keep a param only on dp rank 0 unless ZeRO already partitions it
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if optimizer:
            optimizer_state_dict['state'] = {
                param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
                for k in model_state_dict.keys()
                if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta
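
# A minimal sketch of the typical per-rank call (the `model`/`optimizer`/`dist_meta`
# objects are placeholders; note that a non-positive max_size never triggers a flush,
# so everything then lands in a single shard):
#
#     model_ckpts, optim_ckpts, meta = build_checkpoints(max_size=1024 ** 3,
#                                                        model=model,
#                                                        optimizer=optimizer,
#                                                        dist_meta=dist_meta,
#                                                        eliminate_replica=True)
#     # each element of model_ckpts/optim_ckpts can then be saved as a separate file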


def is_duplicated_list(list_: List[Any]) -> bool:
    """Return ``True`` if all elements of ``list_`` are equal (an empty list counts as duplicated)."""
    if len(list_) == 0:
        return True
    elem = list_[0]
    for x in list_[1:]:
        if x != elem:
            return False
    return True


def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    """Copy a per-parameter optimizer state into an existing one, reusing destination tensors in place."""
    for k, v in src_state.items():
        if k in dest_state:
            old_v = dest_state[k]
            if isinstance(old_v, Tensor):
                old_v.copy_(v)
        else:
            dest_state[k] = v


def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    """Load a (possibly partial) optimizer state dict into ``optimizer``, copying tensor states in place."""
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    # map the saved integer indices to the live Parameter objects, group by group
    idx_to_p: Dict[int, Parameter] = {
        old_id: p for old_id, p in zip(chain.from_iterable(g['params'] for g in saved_groups),
                                       chain.from_iterable(g['params'] for g in groups))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                 "\n\t".join(error_msgs)))