ColossalAI/colossalai/utils/checkpointing.py

213 lines
8.2 KiB
Python
Raw Normal View History

2021-10-28 16:21:23 +00:00
import os
import os.path as osp
import re
from typing import Tuple
from pathlib import Path
2021-10-28 16:21:23 +00:00
import torch
from colossalai.context import Config
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
2021-10-28 16:21:23 +00:00
__all__ = [
'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
2021-10-28 16:21:23 +00:00
'load_checkpoint'
]
def unwrap_config(config: Config):
2022-01-21 02:44:30 +00:00
"""Unwrap Config objects to normal dicts
"""
2021-10-28 16:21:23 +00:00
config_dict = dict()
for k, v in config.items():
if isinstance(v, dict):
config_dict[k] = unwrap_config(v)
else:
config_dict[k] = v
return config_dict
def _get_ranks_name():
# tensor parallel
tp_local_rank = 0
if gpc.is_initialized(ParallelMode.TENSOR):
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
# pipeline parallel
pp_local_rank = 0
if gpc.is_initialized(ParallelMode.PIPELINE):
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
return ranks_name
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
ranks_name = _get_ranks_name()
return f'epoch{epoch}-{ranks_name}{suffix}.pt'
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
2022-03-25 05:02:39 +00:00
"""This is a function to generate the checkpoint path from the tuple
(checkpoint_dir, epoch, suffix, gpu_parallel_rank).
2021-10-28 16:21:23 +00:00
This is useful during generation and recuperation of the checkpoint.
2022-03-25 05:02:39 +00:00
Args:
checkpoint_dir (str): Set up a directory for saving checkpoints.
epoch (int): Epoch number (indicate how many epochs have you trained this model).
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
Returns:
str: The checkpoint path to be generated.
2022-01-21 02:44:30 +00:00
"""
2021-10-28 16:21:23 +00:00
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
return os.path.join(checkpoint_dir, ckpt_filename)
def _ensure_directory_exists(filename: str):
# ensure the directory exists
dirpath = os.path.dirname(filename)
if not os.path.exists(dirpath):
Path(dirpath).mkdir(parents=True, exist_ok=True)
2021-10-28 16:21:23 +00:00
def get_latest_checkpoint_pattern(suffix: str = ''):
2022-03-25 05:02:39 +00:00
"""Generate Regular expression of the latest checkpoint's pattern.
Args:
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''.
2021-10-28 16:21:23 +00:00
2022-03-25 05:02:39 +00:00
Returns:
str: The regular expression of checkpoint pattern.
2022-01-21 02:44:30 +00:00
"""
2021-10-28 16:21:23 +00:00
ranks_name = _get_ranks_name()
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
ckpt_pattern = re.compile(pattern)
2021-10-28 16:21:23 +00:00
return ckpt_pattern
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
2022-03-25 05:02:39 +00:00
"""This is a function to retrieve the latest checkpoint path from the tuple
(checkpoint_dir, suffix, gpu_parallel_rank).
2021-10-28 16:21:23 +00:00
This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
2022-03-25 05:02:39 +00:00
Args:
checkpoint_dir (str): Directory for saving checkpoints
suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
Returns:
str: The latest retrieved checkpoint path.
Raises:
FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given.
2022-01-21 02:44:30 +00:00
"""
2021-10-28 16:21:23 +00:00
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
last_epoch = -1
assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
for filename in os.listdir(checkpoint_dir):
ret = CKPT_NAME_PAT.match(filename)
if ret:
epoch = int(ret[0].split('-')[0].lstrip('epoch'))
if epoch > last_epoch:
last_epoch = epoch
if last_epoch == -1:
ranks_name = _get_ranks_name()
raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
else:
target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
path = osp.join(checkpoint_dir, target_file)
return path
def save_checkpoint(checkpoint_path: str,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs):
2022-03-25 05:02:39 +00:00
"""Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as
model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.
Args:
checkpoint_path (str): Set up a directory for saving checkpoints.
epoch (int): Epoch number (indicate how many epochs have you trained this model).
model (:class:`torch.nn.Module`): Model to be registered.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered.
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
:class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None.
kwargs (dict): additional parameters to be saved.
2022-01-21 02:44:30 +00:00
"""
2021-10-28 16:21:23 +00:00
# for compatibility with normal pytorch nn.Module
if hasattr(model, 'state_dict_for_save_checkpoint'):
model_sd = model.state_dict_for_save_checkpoint()
else:
model_sd = model.state_dict()
# ckpt container
checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
2021-10-28 16:21:23 +00:00
if lr_scheduler is not None:
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
_ensure_directory_exists(checkpoint_path)
torch.save(checkpoint, checkpoint_path)
def load_checkpoint(checkpoint_path: str,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
finetune: bool = False,
strict: bool = True) -> Tuple:
2022-01-21 02:44:30 +00:00
"""Loads the checkpoint file.
2022-03-25 05:02:39 +00:00
2022-01-21 02:44:30 +00:00
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
2022-03-25 05:02:39 +00:00
and its descendants.
If finetune is True, then only the weights and buffers of model should be reloaded.
If strict is True, then the keys of state_dict must exactly match the keys returned
by this modules state_dict() function.
Args:
checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
lr_scheduler to recuperate, defaults to None.
finetune (bool, optional): Whether to finetune the model with new dataset or
continue the pre-training, defaults to False.
strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
of the checkpoint match the names of parameters and buffers in model, defaults to True.
Returns:
Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved).
Raises:
ValueError: Raise error if the model/optimizer cannot successfully be recuperated
2022-01-21 02:44:30 +00:00
"""
2021-10-28 16:21:23 +00:00
# Load the checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
try:
last_epoch = checkpoint.pop('epoch') if not finetune else 0
model.load_state_dict(checkpoint.pop('model'), strict=strict)
except KeyError:
raise ValueError('Checkpoint is corrupted')
if not finetune:
try:
optimizer.load_state_dict(checkpoint.pop('optimizer'))
except KeyError:
raise ValueError('Checkpoint is corrupted')
if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))
return last_epoch, checkpoint