import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer

from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
                        OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict


def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    """Save a (possibly sharded) checkpoint of ``model`` and ``optimizer`` to ``path``.

    Each process writes its own shards. ``dist_meta``, which describes how each parameter
    is distributed across ranks, is required when the world size is greater than 1.
    Extra keyword arguments are stored alongside the checkpoint via ``save_others`` and
    are returned by ``load``.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    if world_size == 1:
        # a global (non-distributed) checkpoint doesn't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(
        max_shard_size, model, optimizer, param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_checkpoint in model_checkpoints:
        writer.save_model(model_checkpoint)
    for optimizer_checkpoint in optimizer_checkpoints:
        writer.save_optimizer(optimizer_checkpoint)
    writer.save_meta(meta_checkpoint)
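
# Illustrative usage (a minimal sketch; the model, optimizer, and path below are
# hypothetical, and in a multi-process run `dist_meta` must describe how each
# parameter is sharded across ranks):
#
#   model = MyModel()
#   optimizer = torch.optim.Adam(model.parameters())
#   save('checkpoints/step_1000', model, optimizer, max_shard_size_gb=2.0, epoch=10)
#
# The extra keyword argument (`epoch=10`) is stored via `save_others` and comes back
# in the dict returned by `load`.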


def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    """Merge a distributed checkpoint at ``path`` into a single global checkpoint at ``output_path``.

    Only rank 0 performs the merge. Returns ``True`` if a new checkpoint was written,
    ``False`` if this rank did nothing or the checkpoint is already global.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
        reader.load_optimizers(), dist_meta_list)
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True
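
# Illustrative usage (a minimal sketch with hypothetical paths): merging the shards
# written by a multi-GPU run into one checkpoint that can be loaded on a single process:
#
#   merge('checkpoints/step_1000', 'checkpoints/step_1000_merged', overwrite=True)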


def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    """Redistribute the checkpoint at ``path`` for a new parallel layout and write it to ``output_path``.

    ``dist_metas`` holds the target per-rank distribution metadata; its length determines
    the new number of processes. Only rank 0 performs the conversion. Returns ``True``
    if a new checkpoint was written, ``False`` otherwise.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    do_redist: bool = False
    if len(dist_meta_list) == nprocs:
        for a, b in zip(dist_metas, dist_meta_list):
            if a != b:
                do_redist = True
                break
    else:
        do_redist = True
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" does not need to be redistributed, nothing to do.')
        return False

    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers],
                                    param_count, param_to_os, paired_os, redist_meta),
        reader.load_optimizers(), dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True
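
# Illustrative usage (a minimal sketch with hypothetical metadata): converting a
# checkpoint written by 4 processes into shards for a 2-process layout. Building
# `redist_meta` and the per-rank `dist_metas` is only shown schematically here;
# in practice they come from the target parallel configuration:
#
#   redist('checkpoints/step_1000', 'checkpoints/step_1000_2gpu',
#          redist_meta, [dist_meta_rank0, dist_meta_rank1])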


def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    """Feed every shard produced by ``shard_generator`` to ``convertor``, then finalize it."""
    for shard_dict in shard_generator:
        convertor.append(shard_dict, dist_meta_list)
    convertor.complete()


def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    """Load a checkpoint from ``path`` into ``model`` (and ``optimizer`` if given).

    The main process first converts the checkpoint in a temporary location so it matches
    the current layout: merging it into a global checkpoint when running on a single
    process, or redistributing it according to ``redist_meta`` and ``dist_metas``
    otherwise. Each rank then loads its own shards. Returns the dict of extra objects
    stored with the checkpoint via ``save``.
    """
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        if wrote:
            read_path = temp_path
    if not is_global:
        # make every rank read from the path chosen by the main process
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            # optimizer.load_state_dict(shard)
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict
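
# Illustrative usage (a minimal sketch; the model, optimizer, and path are hypothetical).
# In a single-process run no distribution metadata is needed:
#
#   model = MyModel()
#   optimizer = torch.optim.Adam(model.parameters())
#   others = load('checkpoints/step_1000', model, optimizer)
#   print(others.get('epoch'))
#
# In a distributed run, `redist_meta` and `dist_metas` describing the current layout
# must be provided so the checkpoint can be redistributed before loading.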