import os
import os.path as osp
import re
from typing import Tuple

import torch

from colossalai.context import Config
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

__all__ = [
    'get_checkpoint_path',
    'get_latest_checkpoint_path',
    'get_latest_checkpoint_pattern',
    'save_checkpoint',
    'load_checkpoint'
]


def unwrap_config(config: Config):
    """Unwrap Config objects to normal dicts.
    """
    config_dict = dict()
    for k, v in config.items():
        if isinstance(v, dict):
            config_dict[k] = unwrap_config(v)
        else:
            config_dict[k] = v
    return config_dict
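
# Usage sketch (illustrative; assumes Config can be built from a nested dict,
# since it is treated above as a dict-like container that unwrap_config
# converts back into plain nested dicts):
#
#   cfg = Config({'parallel': Config({'tensor': {'size': 2}})})
#   unwrap_config(cfg)   # -> {'parallel': {'tensor': {'size': 2}}}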


def _get_ranks_name():
    # tensor parallel
    tp_local_rank = 0
    if gpc.is_initialized(ParallelMode.TENSOR):
        tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)

    # pipeline parallel
    pp_local_rank = 0
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
    return ranks_name


def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
    ranks_name = _get_ranks_name()
    return f'epoch{epoch}-{ranks_name}{suffix}.pt'
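
# For example, with epoch=10, an empty suffix and no parallel context
# initialized (so both local ranks default to 0), the helpers above produce a
# filename like 'epoch10-tp0-pp0.pt'.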


def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
    """Generate the checkpoint path from the (checkpoint_dir, epoch, suffix) tuple
    together with the current parallel ranks. This is useful both when creating a
    checkpoint and when locating it again for loading.

    :param checkpoint_dir: Directory in which checkpoints are saved
    :type checkpoint_dir: str
    :param epoch: Epoch number (i.e. how many epochs the model has been trained)
    :type epoch: int
    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :return: The generated checkpoint path
    :rtype: str
    """
    ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
    return os.path.join(checkpoint_dir, ckpt_filename)
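
# Usage sketch (illustrative; assumes the ColossalAI global context has been
# initialized so that the parallel local ranks can be queried):
#
#   ckpt_path = get_checkpoint_path('./checkpoints', epoch=10, suffix='-vit')
#   # e.g. './checkpoints/epoch10-tp0-pp0-vit.pt' on tensor/pipeline local rank 0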


def _ensure_directory_exists(filename: str):
    # ensure the parent directory of the checkpoint file exists
    dirpath = os.path.dirname(filename)
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)


def get_latest_checkpoint_pattern(suffix: str = ''):
    """Generate the regular expression that matches this rank's checkpoint filenames.

    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :return: Compiled checkpoint filename pattern
    :rtype: re.Pattern
    """
    ranks_name = _get_ranks_name()
    ckpt_pattern = re.compile(rf'epoch(\d+)-{ranks_name}{suffix}\.pt')
    return ckpt_pattern
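
# Usage sketch (illustrative; assumes both tensor- and pipeline-parallel local
# ranks are 0, so filenames follow the 'epoch<N>-tp0-pp0.pt' convention):
#
#   pattern = get_latest_checkpoint_pattern()
#   match = pattern.match('epoch12-tp0-pp0.pt')
#   if match:
#       epoch = int(match.group(1))   # -> 12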


def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
    """Retrieve the latest checkpoint path from the (checkpoint_dir, suffix) pair
    for the current parallel ranks. This is useful when resuming from a checkpoint
    whose epoch number is not known in advance.

    :param checkpoint_dir: Directory in which checkpoints are saved
    :type checkpoint_dir: str
    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :raises FileNotFoundError: Raised when no matching checkpoint file is found in the given directory
    :return: The path of the latest checkpoint
    :rtype: str
    """
    ckpt_name_pattern = get_latest_checkpoint_pattern(suffix=suffix)
    last_epoch = -1
    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'

    for filename in os.listdir(checkpoint_dir):
        ret = ckpt_name_pattern.match(filename)
        if ret:
            epoch = int(ret.group(1))
            if epoch > last_epoch:
                last_epoch = epoch

    if last_epoch == -1:
        ranks_name = _get_ranks_name()
        raise FileNotFoundError(f'Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}')
    else:
        target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
        return osp.join(checkpoint_dir, target_file)
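
# Usage sketch (illustrative; assumes './checkpoints' exists and already holds
# checkpoints saved by save_checkpoint for this rank):
#
#   latest = get_latest_checkpoint_path('./checkpoints')
#   last_epoch, extra_state = load_checkpoint(latest, model, optimizer)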


def save_checkpoint(checkpoint_path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Gather the parameters and buffers of the training components (model,
    optimizer, lr_scheduler, etc.) into a checkpoint dictionary and save it to
    the given path. This method works for both the ColossalAI nn.BaseModel and
    a plain PyTorch nn.Module.

    :param checkpoint_path: Path at which the checkpoint file is saved
    :type checkpoint_path: str
    :param epoch: Epoch number (i.e. how many epochs the model has been trained)
    :type epoch: int
    :param model: Model to be saved
    :type model: torch.nn.Module
    :param optimizer: Optimizer to be saved
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler to be saved, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    """
    # for compatibility with a plain PyTorch nn.Module
    if hasattr(model, 'state_dict_for_save_checkpoint'):
        model_sd = model.state_dict_for_save_checkpoint()
    else:
        model_sd = model.state_dict()

    # checkpoint container
    checkpoint = {
        'epoch': epoch,
        'model': model_sd,
        'optimizer': optimizer.state_dict(),
        **kwargs
    }
    if lr_scheduler is not None:
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    _ensure_directory_exists(checkpoint_path)
    torch.save(checkpoint, checkpoint_path)
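
# Usage sketch (illustrative; 'model', 'optimizer' and 'lr_scheduler' are
# assumed to be the objects used in the training loop):
#
#   ckpt_path = get_checkpoint_path('./checkpoints', epoch=epoch)
#   save_checkpoint(ckpt_path, epoch, model, optimizer, lr_scheduler)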


def load_checkpoint(checkpoint_path: str,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    finetune: bool = False,
                    strict: bool = True) -> Tuple:
    """Load a checkpoint file.

    If finetune is False, training is resumed from the given checkpoint, so the
    parameters and buffers in the state_dict are copied into the model, the
    optimizer and the lr_scheduler (and their descendants). If finetune is True,
    only the weights and buffers of the model are reloaded. If strict is True,
    the keys of the state_dict must exactly match the keys returned by the
    model's state_dict() function.

    :param checkpoint_path: Path of the checkpoint file containing the state_dict to load
    :type checkpoint_path: str
    :param model: Model into which parameters and buffers are reloaded
    :type model: torch.nn.Module
    :param optimizer: Optimizer whose state is to be restored
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler whose state is to be restored, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param finetune: Whether to finetune the model on a new dataset rather than resume training, defaults to False
    :type finetune: bool, optional
    :param strict: Whether to strictly enforce that the keys in the checkpoint's
        :attr:`state_dict` match the names of the parameters and buffers in the
        model, defaults to True
    :type strict: bool, optional
    :raises ValueError: Raised if the model or optimizer state cannot be restored from the checkpoint
    :return: (the epoch number of the retrieved checkpoint, the remaining checkpoint entries)
    :rtype: Tuple
    """
    # load the checkpoint onto CPU
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    try:
        last_epoch = checkpoint.pop('epoch') if not finetune else 0
        model.load_state_dict(checkpoint.pop('model'), strict=strict)
    except KeyError:
        raise ValueError('Checkpoint is corrupted')

    if not finetune:
        try:
            optimizer.load_state_dict(checkpoint.pop('optimizer'))
        except KeyError:
            raise ValueError('Checkpoint is corrupted')
        if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))

    return last_epoch, checkpoint
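
# Usage sketch (illustrative resume loop; assumes './checkpoints' exists and
# that 'model', 'optimizer' and 'lr_scheduler' are constructed before loading):
#
#   try:
#       latest = get_latest_checkpoint_path('./checkpoints')
#       last_epoch, extra = load_checkpoint(latest, model, optimizer, lr_scheduler)
#       start_epoch = last_epoch + 1
#   except FileNotFoundError:
#       start_epoch = 0   # no checkpoint found, start from scratch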