2021-10-28 16:21:23 +00:00
|
|
|
|
import os
|
|
|
|
|
import os.path as osp
|
|
|
|
|
import re
|
|
|
|
|
from typing import Tuple
|
2022-02-14 08:19:24 +00:00
|
|
|
|
from pathlib import Path
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
2021-11-18 11:45:06 +00:00
|
|
|
|
from colossalai.context import Config
|
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
|
from colossalai.core import global_context as gpc
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    'get_checkpoint_path',
    'get_latest_checkpoint_path',
    'get_latest_checkpoint_pattern',
    'save_checkpoint',
    'load_checkpoint',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unwrap_config(config: Config):
    """Recursively convert a Config object into plain built-in dicts.

    Every value that is itself a dict instance is unwrapped the same way; all other values are kept as they are.

    :param config: Mapping to unwrap
    :type config: Config
    :return: A new plain dict holding the same keys
    :rtype: dict
    """
    return {
        key: unwrap_config(value) if isinstance(value, dict) else value
        for key, value in config.items()
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_ranks_name():
    """Build the ``tp<rank>-pp<rank>`` tag that identifies this process in checkpoint filenames.

    A parallel mode that is not initialized in the global context contributes rank 0.

    :return: Rank tag built from the tensor-parallel and pipeline-parallel local ranks
    :rtype: str
    """

    def _local_rank_or_zero(mode):
        # fall back to rank 0 when the given parallel mode has not been set up
        return gpc.get_local_rank(mode) if gpc.is_initialized(mode) else 0

    tensor_rank = _local_rank_or_zero(ParallelMode.TENSOR)
    pipeline_rank = _local_rank_or_zero(ParallelMode.PIPELINE)
    return f'tp{tensor_rank}-pp{pipeline_rank}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
    """Compose the checkpoint filename ``epoch<epoch>-<ranks><suffix>.pt`` for the current ranks.

    :param epoch: Epoch number embedded in the filename
    :type epoch: int
    :param suffix: Text appended right before the ``.pt`` extension, defaults to ''
    :type suffix: str, optional
    :return: Checkpoint filename (without any directory part)
    :rtype: str
    """
    return f'epoch{epoch}-{_get_ranks_name()}{suffix}.pt'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
    """Build the checkpoint path for the given directory, epoch and suffix on the current parallel ranks.

    The same path is produced when writing a checkpoint and when looking it up again later.

    :param checkpoint_dir: Directory in which checkpoints are stored
    :type checkpoint_dir: str
    :param epoch: Epoch number (how many epochs the model has been trained for)
    :type epoch: int
    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :return: ``checkpoint_dir`` joined with the standard checkpoint filename
    :rtype: str
    """
    return os.path.join(checkpoint_dir, _get_standard_checkpoint_filename(epoch, suffix))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_directory_exists(filename: str):
|
|
|
|
|
# ensure the directory exists
|
2022-02-14 08:19:24 +00:00
|
|
|
|
dirpath = os.path.dirname(filename)
|
|
|
|
|
if not os.path.exists(dirpath):
|
|
|
|
|
Path(dirpath).mkdir(parents=True, exist_ok=True)
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_latest_checkpoint_pattern(suffix: str = ''):
    """Generate the regular expression that recognizes checkpoint filenames of the current parallel ranks.

    The pattern has the form ``epoch(<digits>)-<ranks><suffix>.pt``; its first capture group is the epoch number.

    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :return: Checkpoint pattern
    :rtype: compiled regular expression
    """
    ranks_name = _get_ranks_name()
    # The suffix is literal filename text, so escape it: otherwise characters such as '.', '+' or '('
    # would be interpreted as regex syntax and match the wrong files (or fail to compile).
    pattern = rf'epoch(\d+)-{ranks_name}{re.escape(suffix)}\.pt'
    return re.compile(pattern)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
    """Retrieve the path of the checkpoint with the highest epoch number for the current parallel ranks.

    This is useful when resuming from a checkpoint without knowing which epoch was saved last.

    :param checkpoint_dir: Directory for saving checkpoints
    :type checkpoint_dir: str
    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :raises AssertionError: If ``checkpoint_dir`` is not a directory
    :raises FileNotFoundError: If no file in ``checkpoint_dir`` matches the checkpoint naming pattern
    :return: The latest checkpoint path
    :rtype: str
    """
    ckpt_name_pat = get_latest_checkpoint_pattern(suffix=suffix)

    # kept as an assert so that callers observe the same exception type as before
    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'

    # fullmatch (rather than match) so that names which merely start like a checkpoint,
    # e.g. a leftover 'epoch3-tp0-pp0.pt.bak', do not count as an existing checkpoint
    candidates = (ckpt_name_pat.fullmatch(filename) for filename in os.listdir(checkpoint_dir))
    # group 1 of the pattern captures the epoch digits
    last_epoch = max((int(found.group(1)) for found in candidates if found), default=-1)

    if last_epoch == -1:
        ranks_name = _get_ranks_name()
        raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")

    target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
    return osp.join(checkpoint_dir, target_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(checkpoint_path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Collect the state of the training components into one dictionary and write it to ``checkpoint_path``.

    The dictionary holds the epoch, the model state, the optimizer state, every extra keyword argument and,
    when a scheduler is given, the lr_scheduler state. Extra keyword arguments named ``epoch``, ``model`` or
    ``optimizer`` replace the corresponding entry.

    A model that provides ``state_dict_for_save_checkpoint`` has that method used to obtain its state;
    any other model falls back to the regular ``state_dict``.

    :param checkpoint_path: File path the checkpoint is written to; its parent directory is created if needed
    :type checkpoint_path: str
    :param epoch: Epoch number (how many epochs the model has been trained for)
    :type epoch: int
    :param model: Model whose state is saved
    :type model: torch.nn.Module
    :param optimizer: Optimizer whose state is saved
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler whose state is saved, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param kwargs: Additional entries stored in the checkpoint dictionary
    """
    # prefer the dedicated checkpoint hook when the model offers one
    model_sd = (model.state_dict_for_save_checkpoint()
                if hasattr(model, 'state_dict_for_save_checkpoint') else model.state_dict())

    # assemble the checkpoint container
    checkpoint = dict(epoch=epoch, model=model_sd, optimizer=optimizer.state_dict())
    checkpoint.update(kwargs)
    if lr_scheduler is not None:
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    _ensure_directory_exists(checkpoint_path)
    torch.save(checkpoint, checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(checkpoint_path: str,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    finetune: bool = False,
                    strict: bool = True) -> Tuple:
    """Load a checkpoint file and restore the training components from it.

    With ``finetune=False`` training is resumed: the model, the optimizer and (when one is passed and its
    state is stored in the checkpoint) the lr_scheduler are all restored, and the saved epoch is returned.
    With ``finetune=True`` only the model state is restored and the returned epoch is 0.

    ``strict`` is forwarded to ``model.load_state_dict``.

    :param checkpoint_path: Path of the checkpoint file to read
    :type checkpoint_path: str
    :param model: Model to reload parameters and buffers into
    :type model: torch.nn.Module
    :param optimizer: Optimizer to restore (untouched when finetune is True)
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler to restore, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param finetune: Whether to finetune the model with a new dataset instead of resuming training,
        defaults to False
    :type finetune: bool, optional
    :param strict: Whether the keys of the stored model state must exactly match the model, defaults to True
    :type strict: bool, optional
    :raises ValueError: If a required entry is missing or a KeyError occurs while restoring the model/optimizer
    :return: (epoch number, remaining checkpoint entries that were not consumed)
    :rtype: Tuple
    """
    # NOTE(review): torch.load unpickles the file; it should only be pointed at trusted checkpoints.
    state = torch.load(checkpoint_path, map_location='cpu')

    try:
        if finetune:
            # a finetuning run starts counting epochs from scratch
            last_epoch = 0
        else:
            last_epoch = state.pop('epoch')
        model.load_state_dict(state.pop('model'), strict=strict)
    except KeyError:
        raise ValueError('Checkpoint is corrupted')

    if finetune:
        # optimizer / scheduler states are left in the returned dictionary
        return last_epoch, state

    try:
        optimizer.load_state_dict(state.pop('optimizer'))
    except KeyError:
        raise ValueError('Checkpoint is corrupted')

    has_scheduler_state = 'lr_scheduler' in state
    if lr_scheduler is not None and has_scheduler_state:
        lr_scheduler.load_state_dict(state.pop('lr_scheduler'))

    return last_epoch, state
|