2023-04-06 08:23:39 +00:00
# coding=utf-8
2023-07-21 06:39:01 +00:00
import os
2023-05-18 12:05:59 +00:00
import re
2023-06-15 07:21:26 +00:00
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
2023-04-04 07:23:01 +00:00
from pathlib import Path
2023-05-18 12:05:59 +00:00
from typing import Iterator , List , Mapping , Optional , OrderedDict , Tuple
2023-04-04 07:23:01 +00:00
import torch
2023-04-06 08:23:39 +00:00
import torch . nn as nn
2023-10-07 02:45:52 +00:00
from packaging . version import Version
2023-06-15 07:21:26 +00:00
from torch . optim import Optimizer
2023-11-09 09:00:25 +00:00
from torch . utils . _pytree import tree_map
2023-05-18 12:05:59 +00:00
2023-08-25 14:04:57 +00:00
from colossalai . tensor . d_tensor import (
is_customized_distributed_tensor ,
is_distributed_tensor ,
to_global ,
to_global_for_customized_distributed_tensor ,
)
2023-04-06 08:23:39 +00:00
# Base file names used when a checkpoint is written as a single file.
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"

# Names of the index files that map every saved entry to the shard file holding it.
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"

# Name of the file that stores the optimizer's param_groups.
GROUP_FILE_NAME = "pytorch_optim_group.bin"
2023-04-04 07:23:01 +00:00
# ======================================
# General helper functions
# ======================================
2023-05-18 12:05:59 +00:00
2023-04-04 07:23:01 +00:00
def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """
    Compute how much memory the elements of a tensor occupy, in MB.

    The value is used to decide whether adding a tensor to the current shard
    would exceed the shard size, in which case a new shard should be started.

    Args:
        tensor (torch.Tensor): the tensor to measure.

    Returns:
        float: size of the tensor in MB (1 MB = 1024 * 1024 bytes).
    """
    num_bytes = tensor.numel() * tensor.element_size()
    return num_bytes / 1024 / 1024
2023-05-18 12:05:59 +00:00
2023-04-04 07:23:01 +00:00
def is_safetensors_available() -> bool:
    """
    Check whether the `safetensors` package can be imported.

    Returns:
        bool: True if `import safetensors` succeeds, False otherwise.
    """
    try:
        # The import itself is the probe; without it the ImportError handler
        # below could never be reached and the function would always say True.
        import safetensors  # noqa: F401
    except ImportError:
        return False
    return True
def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a dtensor checkpoint.

    A dtensor checkpoint is recognised purely by its name: it ends with
    ".*.safetensors" or ".*.bin", where "*" is a literal asterisk.

    Args:
        checkpoint_file_path (str): path to the checkpoint file. Any
            `os.PathLike` object (e.g. `pathlib.Path`) is accepted as well.

    Returns:
        bool: whether the checkpoint file is a dtensor checkpoint.
    """
    # os.fspath lets callers pass a pathlib.Path, which has no `endswith` method.
    return os.fspath(checkpoint_file_path).endswith((".*.safetensors", ".*.bin"))
def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a safetensor checkpoint.

    The decision is based only on the file name ending with ".safetensors".

    Args:
        checkpoint_file_path (str): path to the checkpoint file. Any
            `os.PathLike` object (e.g. `pathlib.Path`) is accepted as well.

    Returns:
        bool: whether the checkpoint file is a safetensor checkpoint.
    """
    # os.fspath lets callers pass a pathlib.Path, which has no `endswith` method.
    return os.fspath(checkpoint_file_path).endswith(".safetensors")
2023-08-31 06:50:47 +00:00
def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
    """
    Given the current shape of parameter and the shape of parameter before sharding,
    return the dimension along which the parameter is sharded when using tensor parallel.
    If no dimension got shorter (e.g. tensor parallel is not used), return None.

    Args:
        current_shape (torch.Size): The current shape of parameter after sharding.
        original_shape (torch.Size): The shape of parameter before sharding.
        tp_size (int): The size of tp group.

    Returns:
        Optional[int]: The dimension along which parameter is partitioned, or None.

    Raises:
        AssertionError: if the partitioned dimension is not exactly `tp_size` times
            shorter than it was before sharding.
    """
    # The first dimension that is shorter than before sharding is the partition dimension.
    partition_dim = next(
        (dim for dim, length in enumerate(original_shape) if length > current_shape[dim]),
        None,
    )
    if partition_dim is not None:
        # The message is built from adjacent literals rather than a backslash
        # continuation, which would embed the source indentation in the text.
        assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], (
            "The parameter isn't evenly distributed among tensor parallel group: "
            f"shape before sharding {original_shape}, shape after sharding {current_shape}"
        )
    return partition_dim
2023-08-25 14:04:57 +00:00
2023-04-06 08:23:39 +00:00
# ======================================
2023-08-25 14:04:57 +00:00
# Helper classes and functions for saving shard file
2023-04-06 08:23:39 +00:00
# ======================================
2023-07-07 08:33:06 +00:00
2023-08-25 14:04:57 +00:00
class StateDictSharder:
    """
    Accumulate state-dict entries into blocks (shards) bounded by a maximum size.

    Entries are appended one at a time. When a new entry would push the block being
    filled past `size_per_shard` (compared against sizes in MB as computed by
    `calculate_tensor_size`), the filled block is handed back to the caller and the
    entry starts a fresh block. The block still being filled stays available as
    `current_block` / `current_block_size`.
    """

    def __init__(self, size_per_shard: int) -> None:
        self.max_shard_size = size_per_shard
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def _pop_block_if_full(self, incoming_size: float) -> Tuple[Optional[OrderedDict], int]:
        """Hand back and reset the current block if adding `incoming_size` would overflow it."""
        # An empty block is never handed back, so an entry larger than the shard
        # size still ends up in a block of its own.
        if self.current_block_size > 0 and self.current_block_size + incoming_size > self.max_shard_size:
            finished_block, finished_size = self.current_block, self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0
            return finished_block, finished_size
        return None, 0

    def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        """
        Add one named tensor to the block being filled.

        Args:
            name (str): key under which the tensor is stored in the block.
            tensor (torch.Tensor): the tensor to add.

        Returns:
            Tuple[Optional[OrderedDict], int]: the block that was just completed and its
            size, or (None, 0) if the current block still had room.
        """
        size = calculate_tensor_size(tensor)
        finished_block, finished_size = self._pop_block_if_full(size)
        self.current_block[name] = tensor
        self.current_block_size += size
        return finished_block, finished_size

    def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
        """
        Add the optimizer state of one parameter to the block being filled.

        Args:
            param_id (int): key under which the state is stored in the block.
            state (OrderedDict): the state of one parameter; it may hold several tensors,
                e.g. each Adam state includes 'step', 'exp_avg', 'exp_avg_sq'.

        Returns:
            Tuple[Optional[OrderedDict], int]: the block that was just completed and its
            size, or (None, 0) if the current block still had room. (None, 0) is also
            returned, without storing anything, when any tensor of the state is a
            distributed tensor.
        """
        total_size = 0
        has_dtensor = False
        for value in state.values():
            # Non-tensor values (e.g. None for SGD with momentum set to 0) have no
            # size to measure and are skipped.
            if isinstance(value, torch.Tensor):
                if is_distributed_tensor(value):
                    has_dtensor = True
                total_size += calculate_tensor_size(value)

        # States stored as distributed tensors are not added to any block.
        if has_dtensor:
            return None, 0

        finished_block, finished_size = self._pop_block_if_full(total_size)
        self.current_block[param_id] = state
        self.current_block_size += total_size
        return finished_block, finished_size
def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
    """
    Return the complete parameter for saving, gathering it first if it is a distributed tensor.

    Args:
        param (torch.Tensor): A model parameter, might be d_tensor.
        keep_vars (bool, optional): If False, the parameter is detached before anything else
            is done with it. Defaults to False.

    Returns:
        torch.Tensor: the complete parameter
    """
    if not keep_vars:
        param = param.detach()

    if is_distributed_tensor(param):
        return to_global(param)
    if is_customized_distributed_tensor(param):
        return to_global_for_customized_distributed_tensor(param)
    # Not distributed: nothing to gather.
    return param
2023-08-25 14:04:57 +00:00
2023-09-19 06:20:26 +00:00
def save_state_dict_shards(
    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
    checkpoint: str,
    index_file: "CheckpointIndexFile",
    base_filename: str,
    is_master: bool,
    use_safetensors: bool = False,
    use_pp_format: bool = False,
) -> int:
    """
    Write every shard of a sharded state dict to its own file; only the master rank writes.
    This method can be used by both model and optimizer states.

    Args:
        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
        checkpoint (str): The path of checkpoint directory as string.
        index_file (CheckpointIndexFile): The index file object to be updated.
        base_filename (str): Decides the prefix of filenames of shards.
        is_master (bool): Whether current rank is main process.
        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.

    Returns:
        int: the total size of shards (always 0 on non-master ranks, which save nothing).
    """
    saved_files = []
    accumulated_size = 0

    # Non-master ranks still walk through the iterator, but drop every shard unsaved.
    for shard_idx, (shard, shard_size) in enumerate(sharded_state_dict):
        if is_master:
            shard_file = get_shard_filename(base_filename, shard_idx)
            accumulated_size = accumulated_size + shard_size
            # Record which file each entry of this shard lives in.
            for key in shard.keys():
                index_file.append_weight_map(key, shard_file)
            target_path = os.path.join(checkpoint, shard_file)
            save_state_dict(shard, target_path, use_safetensors=use_safetensors)
            saved_files.append(shard_file)
        del shard

    # Remove leftover shard files in the folder that were not written by this call.
    clean_folder(checkpoint, base_filename, saved_files, is_master=is_master, use_pp_format=use_pp_format)

    return accumulated_size
2023-06-15 07:21:26 +00:00
def shard_model_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    Distributed tensors are skipped; they are not part of any yielded shard.

    Args:
        state_dict (dict): mapping from parameter name to tensor.
        max_shard_size (int, optional): maximum size of a shard, in the same unit as
            `calculate_tensor_size` (MB). Defaults to 1024.

    Yields:
        Tuple[OrderedDict, int]: a shard and its size. The last, possibly partially
        filled (or even empty) shard is always yielded at the end.
    """
    state_dict_sharder = StateDictSharder(max_shard_size)

    for key, weight in state_dict.items():
        if is_distributed_tensor(weight):
            continue

        # The yield check is tied to the append that produced `block`, so a skipped
        # distributed tensor can neither re-yield a stale block nor hit an unbound name.
        block, block_size = state_dict_sharder.append_param(key, weight)
        if block is not None:
            yield block, block_size

    # Return the last block in sharder.
    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
2023-04-06 08:23:39 +00:00
2023-06-15 07:21:26 +00:00
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    Only state_dict['state'] is split; state_dict['param_groups'] is not considered in this function.

    Args:
        state_dict (dict): optimizer state dict; its "state" entry maps a parameter id to that
            parameter's state.
        max_shard_size (int, optional): maximum size of a shard, in the same unit as
            `calculate_tensor_size` (MB). Defaults to 1024.

    Yields:
        Tuple[OrderedDict, int]: a shard and its size. The last, possibly partially
        filled (or even empty) shard is always yielded at the end.
    """
    states = state_dict["state"]
    state_dict_sharder = StateDictSharder(max_shard_size)

    for param_id, state in states.items():
        block, block_size = state_dict_sharder.append_optim_state(param_id, state)
        # `is not None` is an identity check; `!= None` would go through `__ne__`.
        if block is not None:
            yield block, block_size

    # Return the last block in sharder.
    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
2023-06-15 07:21:26 +00:00
2023-06-16 06:14:05 +00:00
2023-08-31 06:50:47 +00:00
# ======================================
# Helper functions for saving state dict
# ======================================
2023-06-16 06:14:05 +00:00
2024-04-08 07:09:40 +00:00
2023-08-31 06:50:47 +00:00
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Write a state dict to a checkpoint file, either with `torch.save` or with safetensors.

    Args:
        state_dict (dict): state dict.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """

    def _to_cpu(obj):
        # Non-tensor leaves are passed through untouched.
        return obj.cpu() if torch.is_tensor(obj) else obj

    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
    cpu_state_dict = tree_map(_to_cpu, state_dict)

    if not use_safetensors:
        torch.save(cpu_state_dict, checkpoint_file_path)
        return

    assert is_safetensors_available(), "safetensors is not available."
    assert checkpoint_file_path.endswith(
        ".safetensors"
    ), "safetensors only supports .safetensors suffix for checkpoint file."
    from safetensors.torch import save_file as safe_save_file

    safe_save_file(cpu_state_dict, checkpoint_file_path, metadata={"format": "pt"})
2023-08-31 06:50:47 +00:00
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
    """
    Store the "param_groups" entry of an optimizer state dict in its own file.

    Args:
        state_dict (dict): state dict; must contain the key "param_groups".
        group_file_path (str): path to the group file.
    """
    torch.save(state_dict["param_groups"], group_file_path)
2023-09-19 06:20:26 +00:00
def clean_folder(
    checkpoint_path: str,
    weights_name: str,
    shard_filenames: List[str],
    is_master: bool = True,
    use_pp_format: bool = False,
) -> None:
    """
    Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.

    A file is removed when it is a regular file whose name starts with the base name of
    `weights_name`, looks like a shard file, and is not listed in `shard_filenames`.
    Nothing is done on non-master ranks.

    Args:
        checkpoint_path (str): Path to the checkpoint directory.
        weights_name (str): Decides the prefix of filenames of weight shards.
        shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
        is_master (bool, optional): Whether current rank is main process. Defaults to True.
        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
    """
    if not is_master:
        return

    # Everything below depends only on the arguments, so it is computed once
    # instead of once per directory entry.
    weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
    if use_pp_format:
        # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
        reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
    else:
        reg = re.compile(r"(.*?)-\d{5}")

    for filename in os.listdir(checkpoint_path):
        full_filename = os.path.join(checkpoint_path, filename)
        filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
        if (
            filename.startswith(weights_no_suffix)
            and os.path.isfile(full_filename)
            and filename not in shard_filenames
            and reg.fullmatch(filename_no_suffix) is not None
        ):
            os.remove(full_filename)
def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
    """
    Save config.json/generation_config.json if model is a Huggingface pretrained model.
    This method can only be called when a model is saved in a sharded way.

    The function returns without doing anything when `transformers` cannot be imported
    or when `model` is not a `PreTrainedModel`.

    Args:
        model (nn.Module): The model whose config should be saved if it's a huggingface model.
        checkpoint_path (str): Path to the checkpoint directory.
        is_master (bool): Whether current rank is main process. Files are only written when True.
    """
    try:
        from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
        from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
    except ImportError:
        return

    if not isinstance(model, PreTrainedModel):
        return

    hf_model = unwrap_huggingface_model(model)

    # The config is updated on every rank; only writing it out is restricted to master.
    # Store the string version of dtype in the config, e.g. convert torch.float32 => "float32"
    param_dtype = get_parameter_dtype(hf_model)
    hf_model.config.torch_dtype = str(param_dtype).split(".")[1]
    # Attach architecture to the config
    hf_model.config.architectures = [hf_model.__class__.__name__]

    if not is_master:
        return

    hf_model.config.save_pretrained(checkpoint_path)
    if hf_model.can_generate():
        hf_model.generation_config.save_pretrained(checkpoint_path)
2023-08-31 06:50:47 +00:00
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
    """
    Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
    only one tensor.

    The file is written into a "dtensor" sub-directory of `index_file.root_path`, and the
    weight map of `index_file` is updated with a wildcard file name for `name`.

    Args:
        name (str): name of the tensor; used as the key in the saved dictionary and in the file name.
        tensor (Tensor): tensor to be saved.
        index_file (CheckpointIndexFile): index file object providing `root_path` and receiving the weight-map entry.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    dtensor_dir = index_file.root_path.joinpath("dtensor")
    dtensor_dir.mkdir(exist_ok=True)

    # TODO(YuliangLiu): get index of the tensor shard
    # The shard index is hard-coded to 0 until then.
    shard_index = 0

    # dtensor ckpt file always contains only one tensor
    file_name = generate_dtensor_file_name(name, shard_index, use_safetensors)
    file_path = dtensor_dir.joinpath(file_name)
    save_state_dict({name: tensor}, str(file_path), use_safetensors)

    # In the weight map, "*" takes the place of the shard index and means all shards.
    wildcard_name = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
    index_file.append_weight_map(name, wildcard_name)
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
    """
    Get checkpoint file suffix.

    Args:
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: checkpoint file suffix, including the leading dot (".safetensors" or ".bin").
    """
    return ".safetensors" if use_safetensors else ".bin"


def generate_checkpoint_shard_file_name(
    index: int, total_number: int, use_safetensors: bool, prefix: Optional[str] = None
) -> str:
    """
    Generate checkpoint shard file name.

    Args:
        index (int): index of the shard.
        total_number (int): total number of shards.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
        prefix (str): prefix of the shard file name. Default: None.

    Returns:
        str: checkpoint shard file name, e.g. "00001-of-00003.bin" or
        "<prefix>-00001-of-00003.bin" when a prefix is given.
    """
    # The suffix already starts with ".", so it is appended directly; adding
    # another "." here would produce names such as "00001-of-00003..bin".
    suffix = get_checkpoint_file_suffix(use_safetensors)
    shard_tag = f"{index:05d}-of-{total_number:05d}{suffix}"
    if prefix is None:
        return shard_tag
    return f"{prefix}-{shard_tag}"


def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
    """
    Generate dtensor file name.

    Args:
        param_name (str): name of the distributed parameter.
        index (int): index of the shard. `save_dtensor` also passes the string "*"
            here to build the wildcard name stored in the weight map.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: dtensor file name, e.g. "weight.0.bin".
    """
    # The suffix already starts with ".", so no extra separator is inserted. This
    # also makes the wildcard form end with ".*.bin" / ".*.safetensors", the
    # pattern `is_dtensor_checkpoint` looks for.
    suffix = get_checkpoint_file_suffix(use_safetensors)
    return f"{param_name}.{index}{suffix}"
2023-06-15 07:21:26 +00:00
2023-08-31 06:50:47 +00:00
# ========================================
# Helper functions for loading state dict
# ========================================
2023-06-15 07:21:26 +00:00
2023-05-18 12:05:59 +00:00
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
    """
    Read the state dict stored in one shard file.

    Args:
        checkpoint_file (Path): path to the shard file.
        use_safetensors (bool, optional): read the file with safetensors instead of
            `torch.load`. Defaults to False.

    Returns:
        The loaded state dict. With `torch.load`, tensors are mapped to CPU.

    Raises:
        Exception: if `use_safetensors` is True but the file suffix is not ".safetensors".
        NotImplementedError: if the safetensors metadata declares a format other than "pt".
    """
    if not use_safetensors:
        return torch.load(checkpoint_file, map_location=torch.device("cpu"))

    if checkpoint_file.suffix != ".safetensors":
        raise Exception("load the model using `safetensors`, but no file endwith .safetensors")

    from safetensors.torch import load_file as safe_load_file
    from safetensors.torch import safe_open

    # Only the metadata is needed from this first pass over the file.
    with safe_open(checkpoint_file, framework="pt") as f:
        metadata = f.metadata()
    if metadata["format"] != "pt":
        raise NotImplementedError(
            f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
        )
    return safe_load_file(checkpoint_file)
2023-05-18 12:05:59 +00:00
2023-09-19 06:20:26 +00:00
def load_state_dict_into_model(
    model: nn.Module, state_dict: Mapping, missing_keys: List, strict: bool = False, load_sub_module: bool = True
):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants.

    Args:
        model (nn.Module): the module to load into.
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        missing_keys (List): output argument. The list of keys this call found missing
            is appended to it as ONE element, so after several calls it holds one list
            per call. NOTE(review): callers presumably combine these per-call lists
            (e.g. across shards) — confirm before flattening.
        strict (bool, optional): whether to raise if unexpected keys were recorded.
            Defaults to False.
        load_sub_module (bool, optional): whether to descend into the direct children of
            `model`. Defaults to True.

    Raises:
        TypeError: if `state_dict` is not a mapping.
        RuntimeError: if `strict` is True and unexpected keys were recorded.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

    unexpected_keys: List[str] = []
    sub_missing_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        # A throw-away list is passed in the unexpected-keys slot, so `unexpected_keys`
        # above is never filled by this call.
        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if any(key.startswith(prefix) for key in state_dict):
            module._load_from_state_dict(*args)
            if load_sub_module:
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, state_dict, prefix + name + ".")

    load(model, state_dict, "", load_sub_module)
    del load

    # list.append returns None, so its result must not be assigned back to `missing_keys`.
    missing_keys.append(sub_missing_keys)

    # NOTE(review): `unexpected_keys` is never populated by the code in this function, so
    # this branch looks unreachable, and messages collected in `error_msgs` are otherwise
    # not reported — confirm whether that is intended.
    if strict and len(unexpected_keys) > 0:
        # Keep `error_msgs` a list: joining a plain string would interleave the
        # separator between its individual characters.
        error_msgs.insert(
            0,
            "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)),
        )
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
        )
2023-05-18 12:05:59 +00:00
2023-06-15 07:21:26 +00:00
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
    """
    Load information of param_groups into an initialized optimizer.

    Args:
        optimizer (Optimizer): an initialized optimizer whose `param_groups` get replaced by
            the saved groups, with their "params" taken from the optimizer's current groups.
        param_group_path (str): path to the file holding the saved list of param_groups.

    Returns:
        dict: mapping from each saved parameter id to the corresponding parameter of `optimizer`.

    Raises:
        ValueError: if the loaded object is not a list, or if the number of groups or the
            number of params in any group differs from the optimizer's.
    """
    # Load list of param_groups from given file path.
    # The params in saved_groups are in the form of integer indices.
    saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
    if not isinstance(saved_groups, List):
        raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")

    # The params in param_groups are in the form of pytorch tensors.
    # For more details, please view source code of Optimizer class in pytorch.
    param_groups = optimizer.param_groups

    # Check the compatibility of saved_groups and param_groups.
    if len(param_groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of original parameter groups")
    for current_group, saved_group in zip(param_groups, saved_groups):
        if len(current_group["params"]) != len(saved_group["params"]):
            raise ValueError(
                "loaded state dict contains a parameter group that doesn't match the size of optimizer's group"
            )

    # Pair saved ids with live parameters, group by group, in order. This must be
    # done before the saved groups have their "params" overwritten below.
    saved_ids = chain.from_iterable(g["params"] for g in saved_groups)
    live_params = chain.from_iterable(g["params"] for g in param_groups)
    id_map = dict(zip(saved_ids, live_params))

    # Keep the saved hyper-parameters but point each group at the live parameters.
    updated_groups = []
    for current_group, saved_group in zip(param_groups, saved_groups):
        saved_group["params"] = current_group["params"]
        updated_groups.append(saved_group)

    optimizer.__dict__.update({"param_groups": updated_groups})
    return id_map
2023-08-31 06:50:47 +00:00
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
    r"""Copies states from `state_dict` into an Optimizer object.

    Args:
        optimizer (Optimizer): An initialized Optimizer object to be loaded
        state_dict (dict): A mapping from tensor index (an integer)
            to its states to be loaded (a mapping from state name to a tensor).
        id_map (dict): A mapping from tensor index (an integer)
            to its corresponding parameter (a tensor) whose states will be updated.
        strict (bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False.
    """

    # Ensure that the keys of state_dict are integers.
    state_dict = {int(k): v for k, v in state_dict.items()}

    def cast(param, value, key=None):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
            if key != "step":
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
            return value
        elif isinstance(value, (str, bytes)):
            # Strings are iterable too; sending them through the generic branch below
            # would return `str(<generator>)` instead of the original text.
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v, key=k) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    new_states = defaultdict(dict)
    for k, v in state_dict.items():
        if k in id_map:
            param = id_map[k]
            new_states[param] = cast(param, v)
        elif not strict:
            new_states[k] = v

    optimizer.state.update(new_states)
2023-06-15 07:21:26 +00:00
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
    r"""Do the cleaning up work after state_dict has been loaded into optimizer

    Args:
        optimizer(Optimizer): An optimizer object whose state has just been loaded.
    """
    # Do the cleaning up as in src code of Pytorch.
    # Which private hook exists depends on the installed torch version, so only the
    # attribute matching that version is looked up.
    torch_is_v2_or_later = Version(torch.__version__) >= Version("2.0.0")
    if torch_is_v2_or_later:
        restore_hook = optimizer._patch_step_function  # To support multiprocessing pickle/unpickle
    else:
        restore_hook = optimizer._hook_for_profile  # To support multiprocessing pickle/unpickle.
    restore_hook()

    optimizer.defaults.setdefault("differentiable", False)
2023-06-15 07:21:26 +00:00
2023-04-04 07:23:01 +00:00
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.

    Args:
        checkpoint_path (str): path to the checkpoint; either a file or a directory.

    Returns:
        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path).
        For a file, it is the file itself if its name looks like "<name>.index[.<anything>].json".
        For a directory, it is the single file in it matching "*.index.*json".

    Raises:
        AssertionError: if a directory contains more than one index file.
        RuntimeError: if the path is neither a file nor a directory.
    """
    checkpoint_path = Path(checkpoint_path)

    if checkpoint_path.is_file():
        # check if it is .index.json
        # Raw string with escaped dots: "." must match a literal dot, otherwise a
        # name such as "my_index.json" would count as an index file.
        reg = re.compile(r"(.*?)\.index((\..*)?)\.json")
        if reg.fullmatch(checkpoint_path.name) is not None:
            return True, checkpoint_path
        return False, None

    if checkpoint_path.is_dir():
        # check if there is exactly one file ending with .index.json in this directory
        index_files = list(checkpoint_path.glob("*.index.*json"))
        assert (
            len(index_files) <= 1
        ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
        if len(index_files) == 1:
            return True, index_files[0]
        return False, None

    raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
2023-04-04 07:23:01 +00:00
def load_state_dict(checkpoint_file_path: Path):
    """
    Load state dict from checkpoint.

    Args:
        checkpoint_file_path (Path): path to the checkpoint file (a str is accepted as well).

    Returns:
        dict: state dict, with tensors loaded on CPU.

    Raises:
        AssertionError: if the file is a dtensor checkpoint, or if it is a safetensors
            checkpoint but safetensors is not installed.
    """
    # Normalize to str first: the suffix checks below rely on `str.endswith`,
    # which pathlib.Path does not provide.
    checkpoint_file_path = os.fspath(checkpoint_file_path)

    assert not is_dtensor_checkpoint(
        checkpoint_file_path
    ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."

    if not is_safetensor_checkpoint(checkpoint_file_path):
        # load with torch
        return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))

    assert (
        is_safetensors_available()
    ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."

    # load with safetensors
    from safetensors import safe_open

    with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
        return {k: f.get_tensor(k) for k in f.keys()}
2023-04-12 08:02:17 +00:00
2023-06-15 07:21:26 +00:00
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
    """
    Insert `prefix` as an extra dot-separated component right before the last component of a file name.

    Args:
        weights_name (str): the file name, e.g. "pytorch_model.bin".
        prefix (Optional[str]): the component to insert. None or an empty string leaves
            the name unchanged.

    Returns:
        str: the resulting file name, e.g. "pytorch_model.<prefix>.bin".
    """
    if prefix is None or len(prefix) == 0:
        return weights_name

    *leading_parts, last_part = weights_name.split(".")
    return ".".join([*leading_parts, prefix, last_part])
2023-05-05 06:37:21 +00:00
2023-06-15 07:21:26 +00:00
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
    """
    Generate base model weight filenames.

    Args:
        prefix (str, optional): component inserted into both names via `add_prefix`. Defaults to None.
        use_safetensors (bool, optional): pick the safetensors names instead of the ".bin" ones.
            Defaults to False.

    Returns:
        tuple: (weights file name, index file name).
    """
    if use_safetensors:
        base_weights_name, base_index_name = SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
    else:
        base_weights_name, base_index_name = WEIGHTS_NAME, WEIGHTS_INDEX_NAME

    return add_prefix(base_weights_name, prefix), add_prefix(base_index_name, prefix)
2023-05-05 06:37:21 +00:00
2023-06-15 07:21:26 +00:00
def get_optimizer_base_filenames(prefix: str = None):
    """
    Generate base optimizer state filenames.

    Args:
        prefix (str, optional): component inserted into every name via `add_prefix`. Defaults to None.

    Returns:
        tuple: (states file name, index file name, param group file name).
    """
    base_names = (STATES_NAME, STATES_INDEX_NAME, GROUP_FILE_NAME)
    return tuple(add_prefix(base_name, prefix) for base_name in base_names)
2023-05-05 06:37:21 +00:00
def get_shard_filename(weights_name: str, idx: int):
    """
    Build the file name of one shard from the base file name.

    A "-<number>" tag, with the 1-based shard number zero-padded to 5 digits, is
    inserted in front of the ".bin" / ".safetensors" extension.

    Args:
        weights_name (str): base file name, e.g. "pytorch_model.bin".
        idx (int): 0-based index of the shard.

    Returns:
        str: shard file name, e.g. "pytorch_model-00001.bin" for idx 0.
    """
    shard_tag = f"-{idx + 1:05d}"
    shard_file = weights_name
    for extension in (".bin", ".safetensors"):
        shard_file = shard_file.replace(extension, shard_tag + extension)
    return shard_file