import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer

from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
                        OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict


def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    """Save ``model`` (and optionally ``optimizer``) to ``path``.

    In distributed runs every rank must pass a ``dist_meta`` describing how its
    parameters are distributed. A positive ``max_shard_size_gb`` caps the size of
    each checkpoint shard. Extra keyword arguments are stored alongside the
    checkpoint and returned by ``load``.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    if world_size == 1:
        # global doesn't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                  param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_checkpoint in model_checkpoints:
        writer.save_model(model_checkpoint)
    for optimizer_checkpoint in optimizer_checkpoints:
        writer.save_optimizer(optimizer_checkpoint)
    writer.save_meta(meta_checkpoint)


def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    """Merge a distributed checkpoint at ``path`` into a global one at ``output_path``.

    Only rank 0 performs the merge. Returns ``True`` if a merged checkpoint was
    written, ``False`` if this process is not rank 0 or the checkpoint is already
    global.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
        reader.load_optimizers(), dist_meta_list)
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True


def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    """Redistribute the checkpoint at ``path`` for ``len(dist_metas)`` processes.

    Only rank 0 performs the redistribution. Returns ``True`` if a redistributed
    checkpoint was written to ``output_path``, ``False`` otherwise.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    do_redist: bool = False
    if len(dist_meta_list) == nprocs:
        for a, b in zip(dist_metas, dist_meta_list):
            if a != b:
                do_redist = True
                break
    else:
        do_redist = True
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
        return False
    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
                                    param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True


def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    """Feed every shard produced by ``shard_generator`` through ``convertor``."""
    for shard_dict in shard_generator:
        convertor.append(shard_dict, dist_meta_list)
    convertor.complete()


def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    """Load a checkpoint at ``path`` into ``model`` (and optionally ``optimizer``).

    The main process first merges (single-process case) or redistributes
    (multi-process case, using ``redist_meta`` and ``dist_metas``) the checkpoint
    into a temporary location when needed. Returns the dict of extra objects
    stored by ``save``.
    """
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        if wrote:
            read_path = temp_path
    if not is_global:
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            # optimizer.load_state_dict(shard)
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict
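

# Usage sketch (illustrative only, not part of the module API): a minimal
# single-process round trip with the default 'disk' backend. The model,
# optimizer and checkpoint path below are hypothetical examples, and the
# sketch assumes the backend returns the extra keyword arguments passed to
# `save` unchanged from `load`.
if __name__ == '__main__':
    import torch
    from torch.nn import Linear
    from torch.optim import Adam

    demo_model = Linear(4, 2)
    demo_optimizer = Adam(demo_model.parameters())
    # Run one step so the optimizer has state worth checkpointing.
    demo_model(torch.randn(8, 4)).sum().backward()
    demo_optimizer.step()

    # Without torch.distributed initialized, world_size == 1, so no dist_meta
    # is needed and a single global checkpoint is written.
    save('./demo_ckpt', demo_model, demo_optimizer, overwrite=True, epoch=0)

    # `load` restores the model and optimizer state in place and returns the
    # extra objects saved above (expected to be {'epoch': 0} here).
    extras = load('./demo_ckpt', demo_model, demo_optimizer)
    print(extras)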