[model checkpoint] updated checkpoint save/load utils (#592)

pull/625/head
アマデウス 2022-04-01 16:49:21 +08:00 committed by GitHub
parent 1c40ee8749
commit acae68eb04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 218 additions and 176 deletions

View File

@ -1,9 +1,9 @@
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .activation_checkpoint import checkpoint from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0, ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank, param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param) sync_model_param)
from .data_sampler import DataParallelSampler, get_dataloader from .data_sampler import DataParallelSampler, get_dataloader
@ -18,5 +18,6 @@ __all__ = [
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector' 'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
'ensure_path_exists'
] ]

View File

@ -1,212 +1,253 @@
import os from collections import OrderedDict
import os.path as osp from itertools import chain
import re
from typing import Tuple
from pathlib import Path
import torch import torch
import torch.distributed as dist
from colossalai.context import Config from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
__all__ = [ from .common import is_using_pp
'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
'load_checkpoint' __all__ = ["save_checkpoint", "load_checkpoint"]
]
def unwrap_config(config: Config): def broadcast_state_dict(state_dict, parallel_mode):
"""Unwrap Config objects to normal dicts state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
""" src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
config_dict = dict() dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
for k, v in config.items(): return state_dict[0]
if isinstance(v, dict):
config_dict[k] = unwrap_config(v)
def partition_tensor_parallel_state_dict(
state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
):
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
if gpc.get_local_rank(parallel_mode) == 0:
partitioned_state_list = [dict() for _ in range(depth)]
for key in list(state_dict.keys()):
param = state_dict.pop(key)
dim = dims.get(key, 0)
do_partition = partition_states.get(key, True)
if do_partition:
param = torch.chunk(param, depth, dim=dim)
for i, p in enumerate(partitioned_state_list):
p[key] = param[i] if do_partition else param
else: else:
config_dict[k] = v partitioned_state_list = [None for _ in range(depth)]
return config_dict partitioned_state = [None]
scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
return partitioned_state[0]
def _get_ranks_name(): def gather_tensor_parallel_state_dict(
# tensor parallel state_dict: OrderedDict,
tp_local_rank = 0 parallel_mode: ParallelMode,
if gpc.is_initialized(ParallelMode.TENSOR): dims: dict = dict(),
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR) partition_states: dict = dict(),
keep_vars: bool = False,
):
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
# pipeline parallel for key in list(state_dict.keys()):
pp_local_rank = 0 param = state_dict.pop(key)
if gpc.is_initialized(ParallelMode.PIPELINE): param = param if keep_vars else param.detach()
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) dim = dims.get(key, 0)
do_partition = partition_states.get(key, True)
if do_partition:
temp = param.transpose(0, dim).contiguous()
gather_list = None
if gpc.get_local_rank(parallel_mode) == 0:
shape = list(param.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
param = torch.empty(shape, dtype=param.dtype, device=param.device)
gather_list = list(torch.chunk(param, depth, dim=0))
dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
param = torch.transpose(param, 0, dim)
# update params in state_dict only on local rank 0
if gpc.get_local_rank(parallel_mode) == 0:
state_dict[key] = param
ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}' return state_dict
return ranks_name
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''): def _send_state_dict(state_dict, dst, parallel_mode):
ranks_name = _get_ranks_name() state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
return f'epoch{epoch}-{ranks_name}{suffix}.pt' dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode))
dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode))
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''): def _recv_state_dict(src, parallel_mode):
"""This is a function to generate the checkpoint path from the tuple state_size = torch.tensor([0], dtype=torch.long)
(checkpoint_dir, epoch, suffix, gpu_parallel_rank). dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode))
This is useful during generation and recuperation of the checkpoint. state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode))
Args: state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
checkpoint_dir (str): Set up a directory for saving checkpoints. return state_dict
epoch (int): Epoch number (indicate how many epochs have you trained this model).
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
Returns:
str: The checkpoint path to be generated.
"""
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
return os.path.join(checkpoint_dir, ckpt_filename)
def _ensure_directory_exists(filename: str): def partition_pipeline_parallel_state_dict(model, state_dict):
# ensure the directory exists pipeline_state = OrderedDict()
dirpath = os.path.dirname(filename)
if not os.path.exists(dirpath): if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
Path(dirpath).mkdir(parents=True, exist_ok=True) # receive all states from prev stage
if not gpc.is_first_rank(ParallelMode.PIPELINE):
state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
# move states to output
for name, _ in model.named_parameters(recurse=True):
if name in state_dict:
pipeline_state[name] = state_dict.pop(name)
for name, _ in model.named_buffers(recurse=True):
if name in state_dict:
pipeline_state[name] = state_dict.pop(name)
for name, _ in model.named_modules():
extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
if extra_state_key in state_dict:
pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
# send rest states to next stage
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
return pipeline_state
def get_latest_checkpoint_pattern(suffix: str = ''): def gather_pipeline_parallel_state_dict(state_dict):
"""Generate Regular expression of the latest checkpoint's pattern. gathered_states = (
[None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
else None
)
dist.gather_object(
state_dict,
gathered_states,
dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
group=gpc.get_cpu_group(ParallelMode.PIPELINE),
)
Args: state_dict = (
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''. OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
else OrderedDict()
)
Returns: return state_dict
str: The regular expression of checkpoint pattern.
"""
ranks_name = _get_ranks_name()
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
ckpt_pattern = re.compile(pattern)
return ckpt_pattern
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''): def save_checkpoint(
"""This is a function to retrieve the latest checkpoint path from the tuple file,
(checkpoint_dir, suffix, gpu_parallel_rank).
This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
Args:
checkpoint_dir (str): Directory for saving checkpoints
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
Returns:
str: The latest retrieved checkpoint path.
Raises:
FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given.
"""
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
last_epoch = -1
assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
for filename in os.listdir(checkpoint_dir):
ret = CKPT_NAME_PAT.match(filename)
if ret:
epoch = int(ret[0].split('-')[0].lstrip('epoch'))
if epoch > last_epoch:
last_epoch = epoch
if last_epoch == -1:
ranks_name = _get_ranks_name()
raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
else:
target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
path = osp.join(checkpoint_dir, target_file)
return path
def save_checkpoint(checkpoint_path: str,
epoch: int, epoch: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs): **kwargs
"""Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as ):
model, optimizer, lr_scheduler etc. into a checkpoint dictionary. """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
lr_scheduler etc. into a checkpoint dictionary.
This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.
Args: Args:
checkpoint_path (str): Set up a directory for saving checkpoints. file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
epoch (int): Epoch number (indicate how many epochs have you trained this model). file name.
model (:class:`torch.nn.Module`): Model to be registered. epoch (int): Epoch number (indicates how many epochs have you trained this model).
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered. model (:class:`torch.nn.Module`): Model to be saved.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
:class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None. :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
kwargs (dict): additional parameters to be saved. pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
""" """
# for compatibility with normal pytorch nn.Module
if hasattr(model, 'state_dict_for_save_checkpoint'):
model_sd = model.state_dict_for_save_checkpoint()
else:
model_sd = model.state_dict()
# ckpt container # ckpt container
checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs} checkpoint = {"epoch": epoch}
if lr_scheduler is not None:
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
_ensure_directory_exists(checkpoint_path) model_state = model.state_dict()
torch.save(checkpoint, checkpoint_path) if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
model_state = gather_pipeline_parallel_state_dict(model_state)
if gpc.get_global_rank() == 0:
checkpoint["model"] = model_state
# if optimizer is not None:
# checkpoint['optimizer'] = optimizer.state_dict()
# if lr_scheduler is not None:
# checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
torch.save(checkpoint, file, **kwargs)
def load_checkpoint(checkpoint_path: str, def load_checkpoint(
file,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
finetune: bool = False, strict: bool = True,
strict: bool = True) -> Tuple: ):
"""Loads the checkpoint file. """Loads training states from a checkpoint file.
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
and its descendants.
If finetune is True, then only the weights and buffers of model should be reloaded.
If strict is True, then the keys of state_dict must exactly match the keys returned
by this modules state_dict() function.
Args: Args:
checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict. file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
model (:class:`torch.nn.Module`): Model to reload parameters and buffers. object containing a file name.
model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate. optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
lr_scheduler to recuperate, defaults to None. lr_scheduler to recuperate, defaults to None.
finetune (bool, optional): Whether to finetune the model with new dataset or
continue the pre-training, defaults to False.
strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
of the checkpoint match the names of parameters and buffers in model, defaults to True. of the checkpoint match the names of parameters and buffers in model, defaults to True.
Returns: Returns:
Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved). int: The saved epoch number.
Raises: Raises:
ValueError: Raise error if the model/optimizer cannot successfully be recuperated RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
""" """
# Load the checkpoint. state_dict = (
checkpoint = torch.load(checkpoint_path, map_location='cpu') torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
)
# model states
model_state = state_dict.pop("model") if state_dict is not None else dict()
# pipeline
if is_using_pp():
model_state = partition_pipeline_parallel_state_dict(model, model_state)
try: try:
last_epoch = checkpoint.pop('epoch') if not finetune else 0 model.load_state_dict(model_state, strict=strict)
model.load_state_dict(checkpoint.pop('model'), strict=strict) except RuntimeError as e:
except KeyError: error_msgs = str(e)
raise ValueError('Checkpoint is corrupted') if error_msgs.startswith("Error(s) in loading state_dict for "):
error_msgs = error_msgs.split("\n\t")[1:]
dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
if gpc.get_global_rank() == 0:
all_error_msgs = list(chain.from_iterable(all_error_msgs))
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(all_error_msgs)
)
)
else:
raise e
if not finetune: # broadcast the rest states
try: state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
optimizer.load_state_dict(checkpoint.pop('optimizer'))
except KeyError:
raise ValueError('Checkpoint is corrupted')
if lr_scheduler is not None and 'lr_scheduler' in checkpoint: # # optimizer states
lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler')) # if optimizer is not None and 'optimizer' in state_dict:
# optimizer.load_state_dict(state_dict['optimizer'])
return last_epoch, checkpoint # # lr scheduler states
# if lr_scheduler is not None and 'lr_scheduler' in state_dict:
# lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
# last epoch
last_epoch = state_dict.pop("epoch", -1)
return last_epoch