mirror of https://github.com/hpcaitech/ColossalAI
[model checkpoint] updated checkpoint save/load utils (#592)
parent
1c40ee8749
commit
acae68eb04
|
@ -1,9 +1,9 @@
|
|||
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
|
||||
from .activation_checkpoint import checkpoint
|
||||
|
||||
from .checkpointing import load_checkpoint, save_checkpoint
|
||||
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
|
||||
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
|
||||
is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
|
||||
ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
|
||||
is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
|
||||
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param)
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
|
@ -18,5 +18,6 @@ __all__ = [
|
|||
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
|
||||
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
||||
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
|
||||
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector'
|
||||
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
|
||||
'ensure_path_exists'
|
||||
]
|
||||
|
|
|
@ -1,212 +1,253 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
from typing import Tuple
|
||||
from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.context import Config
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication.collective import scatter_object_list
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
|
||||
__all__ = [
|
||||
'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
|
||||
'load_checkpoint'
|
||||
]
|
||||
from .common import is_using_pp
|
||||
|
||||
__all__ = ["save_checkpoint", "load_checkpoint"]
|
||||
|
||||
|
||||
def unwrap_config(config: Config):
|
||||
"""Unwrap Config objects to normal dicts
|
||||
"""
|
||||
config_dict = dict()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, dict):
|
||||
config_dict[k] = unwrap_config(v)
|
||||
else:
|
||||
config_dict[k] = v
|
||||
|
||||
return config_dict
|
||||
def broadcast_state_dict(state_dict, parallel_mode):
|
||||
state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
|
||||
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
|
||||
return state_dict[0]
|
||||
|
||||
|
||||
def _get_ranks_name():
|
||||
# tensor parallel
|
||||
tp_local_rank = 0
|
||||
if gpc.is_initialized(ParallelMode.TENSOR):
|
||||
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
def partition_tensor_parallel_state_dict(
|
||||
state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
|
||||
):
|
||||
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
|
||||
# pipeline parallel
|
||||
pp_local_rank = 0
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
if gpc.get_local_rank(parallel_mode) == 0:
|
||||
|
||||
ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
|
||||
return ranks_name
|
||||
partitioned_state_list = [dict() for _ in range(depth)]
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
param = state_dict.pop(key)
|
||||
dim = dims.get(key, 0)
|
||||
do_partition = partition_states.get(key, True)
|
||||
if do_partition:
|
||||
param = torch.chunk(param, depth, dim=dim)
|
||||
for i, p in enumerate(partitioned_state_list):
|
||||
p[key] = param[i] if do_partition else param
|
||||
|
||||
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
|
||||
ranks_name = _get_ranks_name()
|
||||
return f'epoch{epoch}-{ranks_name}{suffix}.pt'
|
||||
|
||||
|
||||
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
|
||||
"""This is a function to generate the checkpoint path from the tuple
|
||||
(checkpoint_dir, epoch, suffix, gpu_parallel_rank).
|
||||
This is useful during generation and recuperation of the checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): Set up a directory for saving checkpoints.
|
||||
epoch (int): Epoch number (indicate how many epochs have you trained this model).
|
||||
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
|
||||
|
||||
Returns:
|
||||
str: The checkpoint path to be generated.
|
||||
"""
|
||||
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
|
||||
return os.path.join(checkpoint_dir, ckpt_filename)
|
||||
|
||||
|
||||
def _ensure_directory_exists(filename: str):
|
||||
# ensure the directory exists
|
||||
dirpath = os.path.dirname(filename)
|
||||
if not os.path.exists(dirpath):
|
||||
Path(dirpath).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_latest_checkpoint_pattern(suffix: str = ''):
|
||||
"""Generate Regular expression of the latest checkpoint's pattern.
|
||||
|
||||
Args:
|
||||
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''.
|
||||
|
||||
Returns:
|
||||
str: The regular expression of checkpoint pattern.
|
||||
"""
|
||||
ranks_name = _get_ranks_name()
|
||||
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
|
||||
ckpt_pattern = re.compile(pattern)
|
||||
return ckpt_pattern
|
||||
|
||||
|
||||
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
|
||||
"""This is a function to retrieve the latest checkpoint path from the tuple
|
||||
(checkpoint_dir, suffix, gpu_parallel_rank).
|
||||
This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): Directory for saving checkpoints
|
||||
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
|
||||
|
||||
Returns:
|
||||
str: The latest retrieved checkpoint path.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given.
|
||||
"""
|
||||
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
|
||||
|
||||
last_epoch = -1
|
||||
assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
|
||||
|
||||
for filename in os.listdir(checkpoint_dir):
|
||||
ret = CKPT_NAME_PAT.match(filename)
|
||||
if ret:
|
||||
epoch = int(ret[0].split('-')[0].lstrip('epoch'))
|
||||
if epoch > last_epoch:
|
||||
last_epoch = epoch
|
||||
|
||||
if last_epoch == -1:
|
||||
ranks_name = _get_ranks_name()
|
||||
raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
|
||||
else:
|
||||
target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
|
||||
path = osp.join(checkpoint_dir, target_file)
|
||||
return path
|
||||
partitioned_state_list = [None for _ in range(depth)]
|
||||
|
||||
partitioned_state = [None]
|
||||
scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
|
||||
return partitioned_state[0]
|
||||
|
||||
|
||||
def save_checkpoint(checkpoint_path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
**kwargs):
|
||||
"""Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as
|
||||
model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
|
||||
def gather_tensor_parallel_state_dict(
|
||||
state_dict: OrderedDict,
|
||||
parallel_mode: ParallelMode,
|
||||
dims: dict = dict(),
|
||||
partition_states: dict = dict(),
|
||||
keep_vars: bool = False,
|
||||
):
|
||||
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
|
||||
This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.
|
||||
for key in list(state_dict.keys()):
|
||||
param = state_dict.pop(key)
|
||||
param = param if keep_vars else param.detach()
|
||||
dim = dims.get(key, 0)
|
||||
do_partition = partition_states.get(key, True)
|
||||
if do_partition:
|
||||
temp = param.transpose(0, dim).contiguous()
|
||||
gather_list = None
|
||||
if gpc.get_local_rank(parallel_mode) == 0:
|
||||
shape = list(param.shape)
|
||||
shape[0], shape[dim] = shape[dim], shape[0]
|
||||
shape[0] *= depth
|
||||
param = torch.empty(shape, dtype=param.dtype, device=param.device)
|
||||
gather_list = list(torch.chunk(param, depth, dim=0))
|
||||
dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
|
||||
param = torch.transpose(param, 0, dim)
|
||||
# update params in state_dict only on local rank 0
|
||||
if gpc.get_local_rank(parallel_mode) == 0:
|
||||
state_dict[key] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _send_state_dict(state_dict, dst, parallel_mode):
|
||||
state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
|
||||
dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode))
|
||||
dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode))
|
||||
|
||||
|
||||
def _recv_state_dict(src, parallel_mode):
|
||||
state_size = torch.tensor([0], dtype=torch.long)
|
||||
dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode))
|
||||
state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
|
||||
dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode))
|
||||
state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
|
||||
return state_dict
|
||||
|
||||
|
||||
def partition_pipeline_parallel_state_dict(model, state_dict):
|
||||
pipeline_state = OrderedDict()
|
||||
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# receive all states from prev stage
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
|
||||
# move states to output
|
||||
for name, _ in model.named_parameters(recurse=True):
|
||||
if name in state_dict:
|
||||
pipeline_state[name] = state_dict.pop(name)
|
||||
for name, _ in model.named_buffers(recurse=True):
|
||||
if name in state_dict:
|
||||
pipeline_state[name] = state_dict.pop(name)
|
||||
for name, _ in model.named_modules():
|
||||
extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
|
||||
if extra_state_key in state_dict:
|
||||
pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
|
||||
# send rest states to next stage
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
_send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
|
||||
|
||||
return pipeline_state
|
||||
|
||||
|
||||
def gather_pipeline_parallel_state_dict(state_dict):
|
||||
gathered_states = (
|
||||
[None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
||||
else None
|
||||
)
|
||||
dist.gather_object(
|
||||
state_dict,
|
||||
gathered_states,
|
||||
dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
|
||||
group=gpc.get_cpu_group(ParallelMode.PIPELINE),
|
||||
)
|
||||
|
||||
state_dict = (
|
||||
OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
||||
else OrderedDict()
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
file,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
|
||||
lr_scheduler etc. into a checkpoint dictionary.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): Set up a directory for saving checkpoints.
|
||||
epoch (int): Epoch number (indicate how many epochs have you trained this model).
|
||||
model (:class:`torch.nn.Module`): Model to be registered.
|
||||
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered.
|
||||
file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
|
||||
file name.
|
||||
epoch (int): Epoch number (indicates how many epochs have you trained this model).
|
||||
model (:class:`torch.nn.Module`): Model to be saved.
|
||||
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
|
||||
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
|
||||
:class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None.
|
||||
kwargs (dict): additional parameters to be saved.
|
||||
:class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
|
||||
pickle_module: module used for pickling metadata and objects
|
||||
pickle_protocol: can be specified to override the default protocol
|
||||
"""
|
||||
# for compatibility with normal pytorch nn.Module
|
||||
if hasattr(model, 'state_dict_for_save_checkpoint'):
|
||||
model_sd = model.state_dict_for_save_checkpoint()
|
||||
else:
|
||||
model_sd = model.state_dict()
|
||||
|
||||
# ckpt container
|
||||
checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
|
||||
if lr_scheduler is not None:
|
||||
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
|
||||
checkpoint = {"epoch": epoch}
|
||||
|
||||
_ensure_directory_exists(checkpoint_path)
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
model_state = model.state_dict()
|
||||
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
model_state = gather_pipeline_parallel_state_dict(model_state)
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
checkpoint["model"] = model_state
|
||||
|
||||
# if optimizer is not None:
|
||||
# checkpoint['optimizer'] = optimizer.state_dict()
|
||||
|
||||
# if lr_scheduler is not None:
|
||||
# checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
|
||||
|
||||
torch.save(checkpoint, file, **kwargs)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path: str,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
finetune: bool = False,
|
||||
strict: bool = True) -> Tuple:
|
||||
"""Loads the checkpoint file.
|
||||
def load_checkpoint(
|
||||
file,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
strict: bool = True,
|
||||
):
|
||||
"""Loads training states from a checkpoint file.
|
||||
|
||||
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
|
||||
So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
|
||||
and its descendants.
|
||||
|
||||
If finetune is True, then only the weights and buffers of model should be reloaded.
|
||||
If strict is True, then the keys of state_dict must exactly match the keys returned
|
||||
by this module’s state_dict() function.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
|
||||
model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
|
||||
Args:
|
||||
file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
|
||||
object containing a file name.
|
||||
model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
|
||||
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
|
||||
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
|
||||
lr_scheduler to recuperate, defaults to None.
|
||||
finetune (bool, optional): Whether to finetune the model with new dataset or
|
||||
continue the pre-training, defaults to False.
|
||||
strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
|
||||
of the checkpoint match the names of parameters and buffers in model, defaults to True.
|
||||
|
||||
Returns:
|
||||
Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved).
|
||||
int: The saved epoch number.
|
||||
|
||||
Raises:
|
||||
ValueError: Raise error if the model/optimizer cannot successfully be recuperated
|
||||
RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
|
||||
"""
|
||||
# Load the checkpoint.
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
state_dict = (
|
||||
torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
|
||||
)
|
||||
|
||||
# model states
|
||||
model_state = state_dict.pop("model") if state_dict is not None else dict()
|
||||
# pipeline
|
||||
if is_using_pp():
|
||||
model_state = partition_pipeline_parallel_state_dict(model, model_state)
|
||||
try:
|
||||
last_epoch = checkpoint.pop('epoch') if not finetune else 0
|
||||
model.load_state_dict(checkpoint.pop('model'), strict=strict)
|
||||
except KeyError:
|
||||
raise ValueError('Checkpoint is corrupted')
|
||||
model.load_state_dict(model_state, strict=strict)
|
||||
except RuntimeError as e:
|
||||
error_msgs = str(e)
|
||||
if error_msgs.startswith("Error(s) in loading state_dict for "):
|
||||
error_msgs = error_msgs.split("\n\t")[1:]
|
||||
dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
|
||||
all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
|
||||
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
|
||||
if gpc.get_global_rank() == 0:
|
||||
all_error_msgs = list(chain.from_iterable(all_error_msgs))
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
model.__class__.__name__, "\n\t".join(all_error_msgs)
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
if not finetune:
|
||||
try:
|
||||
optimizer.load_state_dict(checkpoint.pop('optimizer'))
|
||||
except KeyError:
|
||||
raise ValueError('Checkpoint is corrupted')
|
||||
# broadcast the rest states
|
||||
state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
|
||||
|
||||
if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
|
||||
lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))
|
||||
# # optimizer states
|
||||
# if optimizer is not None and 'optimizer' in state_dict:
|
||||
# optimizer.load_state_dict(state_dict['optimizer'])
|
||||
|
||||
return last_epoch, checkpoint
|
||||
# # lr scheduler states
|
||||
# if lr_scheduler is not None and 'lr_scheduler' in state_dict:
|
||||
# lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
|
||||
|
||||
# last epoch
|
||||
last_epoch = state_dict.pop("epoch", -1)
|
||||
|
||||
return last_epoch
|
||||
|
|
Loading…
Reference in New Issue