[model checkpoint] updated checkpoint save/load utils (#592)

アマデウス 2022-04-01 16:49:21 +08:00 committed by GitHub
parent 1c40ee8749
commit acae68eb04
No known key found for this signature in database
2 changed files with 218 additions and 176 deletions

View File

@ -1,9 +1,9 @@
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
from .data_sampler import DataParallelSampler, get_dataloader
@ -18,5 +18,6 @@ __all__ = [
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector'
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',

View File

@ -1,212 +1,253 @@
import os
import os.path as osp
import re
from typing import Tuple
from pathlib import Path
from collections import OrderedDict
from itertools import chain
import torch
from colossalai.context import Config
import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
__all__ = [
'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
from .common import is_using_pp
__all__ = ["save_checkpoint", "load_checkpoint"]
def unwrap_config(config: Config):
"""Unwrap Config objects to normal dicts
config_dict = dict()
for k, v in config.items():
if isinstance(v, dict):
config_dict[k] = unwrap_config(v)
def broadcast_state_dict(state_dict, parallel_mode):
    """Broadcast ``state_dict`` from the first rank of a parallel group.

    Args:
        state_dict: Object to broadcast. A ``dict`` is shallow-copied before
            being handed to the collective; any other value is passed as is.
        parallel_mode: Parallel mode whose CPU process group is used for the
            broadcast; the source is the first rank listed for that group.

    Returns:
        The single element of the object list after
        ``dist.broadcast_object_list`` has completed.
    """
    if isinstance(state_dict, dict):
        payload = state_dict.copy()
    else:
        payload = state_dict
    object_list = [payload]

    root_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    cpu_group = gpc.get_cpu_group(parallel_mode)
    dist.broadcast_object_list(object_list, src=root_rank, group=cpu_group)
    return object_list[0]
def partition_tensor_parallel_state_dict(
state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
if gpc.get_local_rank(parallel_mode) == 0:
partitioned_state_list = [dict() for _ in range(depth)]
for key in list(state_dict.keys()):
param = state_dict.pop(key)
dim = dims.get(key, 0)
do_partition = partition_states.get(key, True)
if do_partition:
param = torch.chunk(param, depth, dim=dim)
for i, p in enumerate(partitioned_state_list):
p[key] = param[i] if do_partition else param
config_dict[k] = v
partitioned_state_list = [None for _ in range(depth)]
return config_dict
partitioned_state = [None]
scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
return partitioned_state[0]
def _get_ranks_name():
# tensor parallel
tp_local_rank = 0
if gpc.is_initialized(ParallelMode.TENSOR):
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
def gather_tensor_parallel_state_dict(
state_dict: OrderedDict,
parallel_mode: ParallelMode,
dims: dict = dict(),
partition_states: dict = dict(),
keep_vars: bool = False,
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
# pipeline parallel
pp_local_rank = 0
if gpc.is_initialized(ParallelMode.PIPELINE):
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
for key in list(state_dict.keys()):
param = state_dict.pop(key)
param = param if keep_vars else param.detach()
dim = dims.get(key, 0)
do_partition = partition_states.get(key, True)
if do_partition:
temp = param.transpose(0, dim).contiguous()
gather_list = None
if gpc.get_local_rank(parallel_mode) == 0:
shape = list(param.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
param = torch.empty(shape, dtype=param.dtype, device=param.device)
gather_list = list(torch.chunk(param, depth, dim=0))
dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
param = torch.transpose(param, 0, dim)
# update params in state_dict only on local rank 0
if gpc.get_local_rank(parallel_mode) == 0:
state_dict[key] = param
ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
return ranks_name
return state_dict
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
    """Return the checkpoint file name for ``epoch`` on the current ranks.

    The name has the form ``epoch<epoch>-<ranks_name><suffix>.pt`` where
    ``<ranks_name>`` is whatever ``_get_ranks_name()`` returns.

    Args:
        epoch (int): Epoch number embedded in the file name.
        suffix (str, optional): Text inserted right before ``.pt``, defaults to ''.

    Returns:
        str: The file name (without any directory component).
    """
    ranks_name = _get_ranks_name()
    filename = f'epoch{epoch}-{ranks_name}{suffix}.pt'
    return filename
def _send_state_dict(state_dict, dst, parallel_mode):
    """Send a picklable object to rank ``dst`` as two point-to-point messages.

    The object is converted with torch.distributed's private
    ``_object_to_tensor`` helper, which yields a payload tensor and a size
    tensor. The size is sent first, then the payload.

    Args:
        state_dict: Object to transmit.
        dst: Destination rank passed to ``dist.send``.
        parallel_mode: Parallel mode whose CPU process group carries the messages.
    """
    payload, payload_size = dist.distributed_c10d._object_to_tensor(state_dict)
    cpu_group = gpc.get_cpu_group(parallel_mode)
    # Size goes first so the peer can allocate a buffer before the payload arrives.
    for message in (payload_size, payload):
        dist.send(message, dst, group=cpu_group)
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
    """Generate the checkpoint path from the tuple
    (checkpoint_dir, epoch, suffix, gpu_parallel_rank).

    This is useful both when writing a checkpoint and when restoring one.

    Args:
        checkpoint_dir (str): Directory in which checkpoints are saved.
        epoch (int): Epoch number (how many epochs the model has been trained).
        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''

    Returns:
        str: The generated checkpoint path.
    """
    return os.path.join(checkpoint_dir, _get_standard_checkpoint_filename(epoch, suffix))
def _recv_state_dict(src, parallel_mode):
    """Receive an object sent by rank ``src`` as a size message plus a payload.

    Args:
        src: Source rank passed to ``dist.recv``.
        parallel_mode: Parallel mode whose CPU process group carries the messages.

    Returns:
        The object rebuilt from the received bytes by torch.distributed's
        private ``_tensor_to_object`` helper.
    """
    # First message: one int64 holding the payload length in bytes.
    size_buffer = torch.tensor([0], dtype=torch.long)
    cpu_group = gpc.get_cpu_group(parallel_mode)
    dist.recv(size_buffer, src, group=cpu_group)

    # Second message: the serialized object itself, as a uint8 tensor.
    payload = torch.empty(size_buffer.item(), dtype=torch.uint8)
    dist.recv(payload, src, group=cpu_group)

    return dist.distributed_c10d._tensor_to_object(payload, size_buffer)
def _ensure_directory_exists(filename: str):
    """Create the parent directory of ``filename`` when it is missing.

    Args:
        filename (str): Path of a file whose containing directory must exist.
            Intermediate directories are created as needed; the file itself
            is not created.
    """
    parent_dir = os.path.dirname(filename)
    if os.path.exists(parent_dir):
        # Nothing to do when the path is already present.
        return
    Path(parent_dir).mkdir(parents=True, exist_ok=True)
def partition_pipeline_parallel_state_dict(model, state_dict):
    """Pick out of ``state_dict`` the entries that belong to this pipeline stage.

    Only ranks whose local tensor-parallel rank is 0 take part. Such a rank
    that is not the first pipeline rank first replaces ``state_dict`` with
    the object received from the previous pipeline rank. It then moves every
    entry whose key matches one of ``model``'s parameters, buffers or module
    extra-state keys into the result, and, unless it is the last pipeline
    rank, forwards whatever is left to the next pipeline rank.

    Args:
        model: Module whose parameter, buffer and sub-module names decide
            which keys are claimed by this stage.
        state_dict: Mapping of state entries. Claimed keys are popped from
            it, so the mapping is modified in place.

    Returns:
        OrderedDict: The entries claimed by this stage; empty on ranks whose
        local tensor-parallel rank is not 0.
    """
    pipeline_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # receive all states from prev stage
        if not gpc.is_first_rank(ParallelMode.PIPELINE):
            state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)

        # move states to output
        for name, _ in model.named_parameters(recurse=True):
            if name in state_dict:
                pipeline_state[name] = state_dict.pop(name)
        for name, _ in model.named_buffers(recurse=True):
            if name in state_dict:
                pipeline_state[name] = state_dict.pop(name)
        for name, _ in model.named_modules():
            # The root module is yielded with an empty name and PyTorch stores
            # its extra state under the bare suffix (no leading dot), so only
            # prefix the key for real sub-modules.
            if name:
                extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
            else:
                extra_state_key = _EXTRA_STATE_KEY_SUFFIX
            if extra_state_key in state_dict:
                pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)

        # send rest states to next stage
        if not gpc.is_last_rank(ParallelMode.PIPELINE):
            _send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)

    return pipeline_state
def get_latest_checkpoint_pattern(suffix: str = ''):
"""Generate Regular expression of the latest checkpoint's pattern.
def gather_pipeline_parallel_state_dict(state_dict):
gathered_states = (
[None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
else None
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''.
state_dict = (
OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
else OrderedDict()
str: The regular expression of checkpoint pattern.
ranks_name = _get_ranks_name()
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
ckpt_pattern = re.compile(pattern)
return ckpt_pattern
return state_dict
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
    """Retrieve the latest checkpoint path from the tuple
    (checkpoint_dir, suffix, gpu_parallel_rank).

    This is useful when restoring a checkpoint without knowing the epoch
    number: the directory is scanned and the highest epoch found among the
    files named for the current ranks is used.

    Args:
        checkpoint_dir (str): Directory in which checkpoints are saved.
        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''

    Returns:
        str: The latest retrieved checkpoint path.

    Raises:
        FileNotFoundError: Raised when no checkpoint file matching the given inputs is found.
    """
    CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)

    last_epoch = -1
    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'

    for filename in os.listdir(checkpoint_dir):
        # fullmatch, not match: a leftover such as "epoch9-...pt.tmp" must not
        # be counted, otherwise the returned path may point at a missing file.
        ret = CKPT_NAME_PAT.fullmatch(filename)
        if ret:
            # The first capture group of the pattern is the epoch number.
            last_epoch = max(last_epoch, int(ret.group(1)))

    if last_epoch == -1:
        ranks_name = _get_ranks_name()
        raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")

    target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
    return osp.join(checkpoint_dir, target_file)
def save_checkpoint(checkpoint_path: str,
def save_checkpoint(
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
"""Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as
model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
lr_scheduler etc. into a checkpoint dictionary.
checkpoint_path (str): Set up a directory for saving checkpoints.
epoch (int): Epoch number (indicate how many epochs have you trained this model).
model (:class:`torch.nn.Module`): Model to be registered.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered.
file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
file name.
epoch (int): Epoch number (indicates how many epochs have you trained this model).
model (:class:`torch.nn.Module`): Model to be saved.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
:class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None.
kwargs (dict): additional parameters to be saved.
:class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
# for compatibility with normal pytorch nn.Module
if hasattr(model, 'state_dict_for_save_checkpoint'):
model_sd = model.state_dict_for_save_checkpoint()
model_sd = model.state_dict()
# ckpt container
checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
if lr_scheduler is not None:
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
checkpoint = {"epoch": epoch}
torch.save(checkpoint, checkpoint_path)
model_state = model.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
model_state = gather_pipeline_parallel_state_dict(model_state)
if gpc.get_global_rank() == 0:
checkpoint["model"] = model_state
# if optimizer is not None:
# checkpoint['optimizer'] = optimizer.state_dict()
# if lr_scheduler is not None:
# checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
torch.save(checkpoint, file, **kwargs)
def load_checkpoint(checkpoint_path: str,
def load_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
finetune: bool = False,
strict: bool = True) -> Tuple:
"""Loads the checkpoint file.
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
and its descendants.
If finetune is True, then only the weights and buffers of model should be reloaded.
If strict is True, then the keys of state_dict must exactly match the keys returned
by this modules state_dict() function.
strict: bool = True,
"""Loads training states from a checkpoint file.
checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
object containing a file name.
model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
lr_scheduler to recuperate, defaults to None.
finetune (bool, optional): Whether to finetune the model with new dataset or
continue the pre-training, defaults to False.
strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
of the checkpoint match the names of parameters and buffers in model, defaults to True.
Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved).
int: The saved epoch number.
ValueError: Raise error if the model/optimizer cannot successfully be recuperated
RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
# Load the checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = (
torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
# model states
model_state = state_dict.pop("model") if state_dict is not None else dict()
# pipeline
if is_using_pp():
model_state = partition_pipeline_parallel_state_dict(model, model_state)
last_epoch = checkpoint.pop('epoch') if not finetune else 0
model.load_state_dict(checkpoint.pop('model'), strict=strict)
except KeyError:
raise ValueError('Checkpoint is corrupted')
model.load_state_dict(model_state, strict=strict)
except RuntimeError as e:
error_msgs = str(e)
if error_msgs.startswith("Error(s) in loading state_dict for "):
error_msgs = error_msgs.split("\n\t")[1:]
dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
if gpc.get_global_rank() == 0:
all_error_msgs = list(chain.from_iterable(all_error_msgs))
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(all_error_msgs)
raise e
if not finetune:
except KeyError:
raise ValueError('Checkpoint is corrupted')
# broadcast the rest states
state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
# # optimizer states
# if optimizer is not None and 'optimizer' in state_dict:
# optimizer.load_state_dict(state_dict['optimizer'])
return last_epoch, checkpoint
# # lr scheduler states
# if lr_scheduler is not None and 'lr_scheduler' in state_dict:
# lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
# last epoch
last_epoch = state_dict.pop("epoch", -1)
return last_epoch