You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/utils/checkpointing.py

215 lines
8.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import os.path as osp
import re
from typing import Tuple
import torch
from colossalai.context import Config
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
# Public API exported by ``from ... import *``.
__all__ = [
    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern',
    'save_checkpoint', 'load_checkpoint',
]
def unwrap_config(config: Config):
    """Recursively convert a Config into plain ``dict`` objects.

    Every value that is a ``dict`` instance is unwrapped recursively; any other
    value is copied into the result unchanged.

    :param config: Mapping to unwrap (iterated through ``config.items()``)
    :type config: Config
    :return: A new plain dict with the same keys and (recursively unwrapped) values
    :rtype: dict
    """
    return {
        key: unwrap_config(value) if isinstance(value, dict) else value
        for key, value in config.items()
    }
def _get_ranks_name():
    """Return the ``'tp{tensor_rank}-pp{pipeline_rank}'`` tag of the current process.

    A parallel mode that has not been initialized contributes rank 0.
    """
    def _local_rank_or_zero(mode):
        # Uninitialized parallel groups behave like a single-rank group.
        return gpc.get_local_rank(mode) if gpc.is_initialized(mode) else 0

    # f-string fields are evaluated left to right: TENSOR first, then PIPELINE.
    return f'tp{_local_rank_or_zero(ParallelMode.TENSOR)}-pp{_local_rank_or_zero(ParallelMode.PIPELINE)}'
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
    """Return the per-rank checkpoint file name ``epoch{epoch}-tp{t}-pp{p}{suffix}.pt``."""
    return 'epoch{}-{}{}.pt'.format(epoch, _get_ranks_name(), suffix)
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
    """Build the checkpoint file path of the current process.

    The file name combines ``epoch``, the tensor/pipeline parallel ranks of this
    process and ``suffix``, so that every rank gets its own file inside
    ``checkpoint_dir``. Useful both when writing a checkpoint and when locating
    it again later.

    :param checkpoint_dir: Directory that holds the checkpoints
    :type checkpoint_dir: str
    :param epoch: Epoch number (how many epochs the model has been trained for)
    :type epoch: int
    :param suffix: Additional notation to distinguish models/checkpoints, defaults to ''
    :type suffix: str, optional
    :return: Path of the checkpoint file
    :rtype: str
    """
    return osp.join(checkpoint_dir, _get_standard_checkpoint_filename(epoch, suffix))
def _ensure_directory_exists(filename: str):
    """Create the parent directory of ``filename`` (and missing ancestors) if needed.

    A bare file name such as ``'ckpt.pt'`` has no directory component, so nothing
    is created for it.

    :param filename: Path of a file that is about to be written
    :type filename: str
    """
    dirname = os.path.dirname(filename)
    # os.path.dirname('ckpt.pt') == '' and os.makedirs('') raises, so skip that case.
    # exist_ok=True replaces the racy exists()-then-makedirs() check: several
    # processes may save their per-rank files into the same directory at once.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
def get_latest_checkpoint_pattern(suffix: str = ''):
    """Compile the regular expression matching this process's checkpoint file names.

    The pattern matches ``epoch<N>-tp<t>-pp<p><suffix>.pt`` for the current
    process's tensor/pipeline parallel ranks; group 1 captures the epoch number.

    :param suffix: Additional notation to distinguish models/checkpoints, defaults to ''
    :type suffix: str, optional
    :return: Compiled checkpoint-name pattern
    :rtype: re.Pattern
    """
    ranks_name = _get_ranks_name()
    # Raw string: avoids invalid-escape warnings for \d and \. .
    # re.escape: the rank tag and a user-provided suffix are literal text, not regex syntax.
    # '$': reject names with trailing text such as 'epoch3-tp0-pp0.pt.tmp'.
    ckpt_pattern = re.compile(rf'epoch(\d+)-{re.escape(ranks_name)}{re.escape(suffix)}\.pt$')
    return ckpt_pattern
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
    """Return the path of this process's checkpoint with the highest epoch number.

    Useful when resuming from a checkpoint whose epoch number is not known. Only
    files whose whole name has the form ``epoch<N>-tp<t>-pp<p><suffix>.pt`` for the
    current ranks are considered.

    :param checkpoint_dir: Directory that holds the checkpoints
    :type checkpoint_dir: str
    :param suffix: Additional notation to distinguish models/checkpoints, defaults to ''
    :type suffix: str, optional
    :raises FileNotFoundError: If no matching checkpoint file exists in ``checkpoint_dir``
    :return: Path of the latest checkpoint file
    :rtype: str
    """
    ckpt_name_pattern = get_latest_checkpoint_pattern(suffix=suffix)
    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
    # fullmatch: names with trailing text (e.g. a partially written
    # 'epoch3-tp0-pp0.pt.tmp') must not count, otherwise the path rebuilt below
    # could point to a file that does not exist. Group 1 is the epoch number.
    epochs = [
        int(found.group(1))
        for found in map(ckpt_name_pattern.fullmatch, os.listdir(checkpoint_dir))
        if found is not None
    ]
    if not epochs:
        ranks_name = _get_ranks_name()
        raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
    target_file = _get_standard_checkpoint_filename(max(epochs), suffix=suffix)
    return osp.join(checkpoint_dir, target_file)
def save_checkpoint(checkpoint_path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Save the state of the training components into one checkpoint file.

    The checkpoint dictionary holds ``'epoch'``, ``'model'``, ``'optimizer'``, every
    extra keyword argument, and ``'lr_scheduler'`` when a scheduler is given.
    Models that provide ``state_dict_for_save_checkpoint`` (ColossalAI models) use
    it; plain PyTorch ``nn.Module`` objects fall back to ``state_dict``.

    :param checkpoint_path: Path of the checkpoint file to write; its parent directory is created if missing
    :type checkpoint_path: str
    :param epoch: Epoch number (how many epochs the model has been trained for)
    :type epoch: int
    :param model: Model whose state is saved
    :type model: torch.nn.Module
    :param optimizer: Optimizer whose state is saved
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler whose state is saved, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param kwargs: Extra entries stored as-is in the checkpoint dictionary
    """
    model_state = (model.state_dict_for_save_checkpoint()
                   if hasattr(model, 'state_dict_for_save_checkpoint')
                   else model.state_dict())

    checkpoint = dict(epoch=epoch, model=model_state, optimizer=optimizer.state_dict())
    checkpoint.update(kwargs)
    if lr_scheduler is not None:
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    _ensure_directory_exists(checkpoint_path)
    torch.save(checkpoint, checkpoint_path)
def load_checkpoint(checkpoint_path: str,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    finetune: bool = False,
                    strict: bool = True) -> Tuple:
    """Load a checkpoint file into the given training components.

    If ``finetune`` is False, training is resumed from the checkpoint: the model,
    the optimizer and, when given and present in the checkpoint, the lr_scheduler
    are restored, and the stored epoch is returned.
    If ``finetune`` is True, only the model's parameters and buffers are restored
    and the returned epoch is 0.
    If ``strict`` is True, the keys of the stored model state_dict must exactly
    match the keys returned by ``model.state_dict()``.

    :param checkpoint_path: Path of the checkpoint file to load
    :type checkpoint_path: str
    :param model: Model to restore parameters and buffers into
    :type model: torch.nn.Module
    :param optimizer: Optimizer to restore (ignored when ``finetune`` is True)
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler to restore, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param finetune: Whether to finetune the model on a new dataset (True) or resume training (False), defaults to False
    :type finetune: bool, optional
    :param strict: Whether to strictly enforce that the keys in the checkpoint's
        model state_dict match the names of parameters and buffers in ``model``, defaults to True
    :type strict: bool, optional
    :raises ValueError: If a required entry ('epoch' and 'model', plus 'optimizer'
        when resuming) is missing; the underlying ``KeyError`` is chained as ``__cause__``
    :return: (epoch number of the checkpoint, or 0 when finetuning; the loaded
        checkpoint dictionary with the entries consumed here removed)
    :rtype: Tuple
    """
    # NOTE(security): torch.load unpickles the file (unless the installed torch
    # defaults to weights_only=True); only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Only the dictionary lookups belong in the try: a KeyError raised by the
    # model itself is not a sign of a corrupted checkpoint file.
    try:
        last_epoch = 0 if finetune else checkpoint.pop('epoch')
        model_state = checkpoint.pop('model')
    except KeyError as err:
        raise ValueError('Checkpoint is corrupted') from err
    model.load_state_dict(model_state, strict=strict)

    if not finetune:
        try:
            optimizer_state = checkpoint.pop('optimizer')
        except KeyError as err:
            raise ValueError('Checkpoint is corrupted') from err
        optimizer.load_state_dict(optimizer_state)
        # The scheduler entry is optional (save_checkpoint only writes it when given one).
        if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))
    return last_epoch, checkpoint