# mirror of https://github.com/hpcaitech/ColossalAI
import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from .meta import ParamDistMeta


def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    """Call ``fn(arg)`` and return its result if ``arg`` is not None; otherwise do nothing."""
    if arg is not None:
        return fn(arg)


def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    """Map each parameter name of ``model`` to its index in ``optimizer.state_dict()['state']``.

    The indices follow the same flat numbering across param groups that
    ``torch.optim.Optimizer.state_dict()`` uses.
    """
    # ensure all params in optimizer are in model state dict
    params_set = set(id(p) for p in model.parameters())
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in params_set
    param_mappings = {}
    start_index = 0

    def get_group_mapping(group):
        nonlocal start_index
        param_mappings.update(
            {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
        start_index += len(group['params'])

    for g in optimizer.param_groups:
        get_group_mapping(g)
    return {k: param_mappings[id(p)] for k, p in model.named_parameters()}


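# Illustrative sketch (not part of the upstream API): how ``get_param_to_os`` lines up parameter
# names with optimizer state indices. The tiny two-layer model is an arbitrary example.
#
#   import torch
#   from torch import nn
#
#   model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
#   optimizer = torch.optim.Adam(model.parameters())
#   param_to_os = get_param_to_os(model, optimizer)
#   # -> {'0.weight': 0, '0.bias': 1, '1.weight': 2, '1.bias': 3}
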
def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    """Return the total size in bytes of all tensors in a single parameter's optimizer state."""
    size = 0
    for v in state.values():
        if isinstance(v, Tensor):
            size += v.numel() * v.element_size()
    return size


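# Illustrative sketch, continuing the example above: after Adam has taken a step, each parameter's
# state typically holds ``exp_avg`` and ``exp_avg_sq`` tensors, so a 4x4 fp32 weight contributes
# roughly 2 * 16 * 4 = 128 bytes (exact numbers depend on the optimizer and PyTorch version).
#
#   state = optimizer.state_dict()['state'][0]
#   print(compute_optimizer_state_size(state))
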
class ModelCheckpointSharder:
    """Accumulate ``(name, tensor)`` pairs and flush a shard once the buffer reaches ``max_shard_size`` bytes.

    A non-positive ``max_shard_size`` disables flushing, so everything ends up in a single shard.
    """

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        """Add a tensor to the buffer; return the previous shard if the buffer was already full, else None."""
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {}
            self.buffer_size = 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return retval

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        """Feed a whole state dict through ``append`` and return every completed shard."""
        shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        """Return the last, possibly partial shard, or None if the buffer is empty."""
        return self.buffer if len(self.buffer) > 0 else None


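# Illustrative sketch: splitting a model state dict into shards. The 128-byte limit is an
# arbitrary value chosen to force several shards for the tiny model above.
#
#   sharder = ModelCheckpointSharder(max_shard_size=128)
#   shards = sharder.extend(model.state_dict())
#   run_if_not_none(shards.append, sharder.complete())    # flush the last partial shard
#   # every element of ``shards`` is a plain {name: tensor} dict
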
class OptimizerCheckpointSharder:
    """Like ``ModelCheckpointSharder``, but shards an optimizer state dict.

    ``param_groups`` is stored only in the first shard; later shards carry only ``state`` entries.
    """

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        """Add one parameter's state to the buffer; return the previous shard if the buffer was full, else None."""
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return retval

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        """Feed ``state_dict['state']`` through ``append`` and return every completed shard."""
        shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        """Return the last, possibly partial shard, or None if it holds no state."""
        return self.buffer if len(self.buffer['state']) > 0 else None


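# Illustrative sketch: the optimizer counterpart of the example above. Only the first shard
# carries ``param_groups``; the rest contain just per-parameter ``state`` entries.
#
#   osd = optimizer.state_dict()
#   sharder = OptimizerCheckpointSharder(max_shard_size=128, param_groups=osd['param_groups'])
#   shards = sharder.extend(osd)
#   run_if_not_none(shards.append, sharder.complete())
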
def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict], dict]:
    """Split a model state dict, and optionally an optimizer state dict, into shards of roughly ``max_shard_size`` bytes."""
    has_optimizer: bool = False
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        os_to_param = {v: k for k, v in param_to_os.items()}
        # every optimizer state entry must correspond to a parameter in the model state dict
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
        has_optimizer = True
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not has_optimizer:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards


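# Illustrative sketch: ``shard_checkpoint`` wires the two sharders together. ``param_to_os`` is
# required whenever an optimizer state dict is passed, since the consistency checks rely on it.
#
#   param_to_os = get_param_to_os(model, optimizer)
#   model_shards, optimizer_shards = shard_checkpoint(128, model.state_dict(),
#                                                     optimizer.state_dict(), param_to_os)
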
def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    """For each optimizer state entry, flag which values are tensors shaped like their parameter.

    Returns a nested dict mapping ``state index -> state key -> bool``, saved as checkpoint metadata.
    """
    os_to_param = {v: k for k, v in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        paired_os[idx] = {}
        p = model_state_dict[os_to_param[idx]]
        for k, v in state.items():
            paired_os[idx][k] = isinstance(v, Tensor) and v.shape == p.shape
    return paired_os


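# Illustrative sketch: for Adam, ``exp_avg`` and ``exp_avg_sq`` share their parameter's shape while
# ``step`` does not, so a typical entry looks like ``{'step': False, 'exp_avg': True, 'exp_avg_sq': True}``.
#
#   paired_os = get_paired_os(model.state_dict(), optimizer.state_dict(), param_to_os)
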
def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    """Build sharded model/optimizer checkpoints plus a metadata dict describing them.

    When ``dist_meta`` is given and ``eliminate_replica`` is True, parameters replicated across
    data-parallel ranks are kept only on dp rank 0, while ZeRO-sharded parameters are always kept.
    """
    save_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp replicated params
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if optimizer:
            optimizer_state_dict['state'] = {
                param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
                for k in model_state_dict.keys()
                if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(
        max_size, model_state_dict, optimizer_state_dict, param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta


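# Illustrative sketch: ``build_checkpoints`` is the high-level entry point. With ``dist_meta=None``
# it produces a "global" (non-distributed) checkpoint; the 1 GiB shard size is an assumed value.
#
#   model_ckpts, optim_ckpts, meta = build_checkpoints(1 << 30, model, optimizer)
#   # ``model_ckpts`` / ``optim_ckpts`` are lists of shard dicts; ``meta`` records dist_meta,
#   # param_to_os, paired_os and the names of the saved parameters.
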
def is_duplicated_list(list_: List[Any]) -> bool:
    """Return True if all elements of ``list_`` are equal; an empty list counts as True."""
    if len(list_) == 0:
        return True
    elem = list_[0]
    for x in list_[1:]:
        if x != elem:
            return False
    return True


def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    """Copy one parameter's optimizer state from ``src_state`` into ``dest_state`` in place.

    Tensor values already present in ``dest_state`` are overwritten element-wise with ``copy_``
    (preserving device and dtype); missing keys are inserted as-is; existing non-tensor values
    are left untouched.
    """
    for k, v in src_state.items():
        if k in dest_state:
            old_v = dest_state[k]
            if isinstance(old_v, Tensor):
                old_v.copy_(v)
        else:
            dest_state[k] = v


def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    """Load ``state_dict`` into ``optimizer`` by copying per-parameter states in place.

    With ``strict=True``, a RuntimeError is raised when the saved state and the optimizer's
    parameters do not match exactly.
    """
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    # map each saved state index to the live Parameter at the same position
    idx_to_p: Dict[int, Parameter] = {
        old_id: p
        for old_id, p in zip(chain.from_iterable(g['params'] for g in saved_groups),
                             chain.from_iterable(g['params'] for g in groups))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                optimizer.__class__.__name__, "\n\t".join(error_msgs)))
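

# Illustrative sketch: restoring an optimizer. Unlike ``Optimizer.load_state_dict``, tensors already
# held in ``optimizer.state`` are overwritten in place, preserving their device and dtype;
# ``saved_state_dict`` is a hypothetical full (unsharded) optimizer state dict.
#
#   new_optimizer = torch.optim.Adam(model.parameters())
#   optimizer_load_state_dict(new_optimizer, saved_state_dict, strict=True)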