# coding=utf-8
import re
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

import torch
import torch.nn as nn

from colossalai.tensor.d_tensor.d_tensor import DTensor

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

# ======================================
# General helper functions
# ======================================
def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """
    Calculate the size of a parameter in MB. Used to compute whether a group of params exceeds the shard size.
    If so, a new shard should be created.

    Args:
        tensor (torch.Tensor): the tensor to calculate size for.

    Returns:
        float: size of the tensor in MB.
    """
    return tensor.numel() * tensor.element_size() / 1024 / 1024
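
# Illustrative arithmetic check (comment only): a 1024 x 1024 float32 tensor
# occupies 1024 * 1024 * 4 bytes = 4 MiB, so
# calculate_tensor_size(torch.rand(1024, 1024)) returns 4.0.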


def is_safetensors_available() -> bool:
    """
    Check whether safetensors is available.

    Returns:
        bool: whether safetensors is available.
    """
    try:
        import safetensors
        return True
    except ImportError:
        return False


def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a dtensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a dtensor checkpoint.
    """
    # dtensor entries in the weight map use a literal '*' in place of the shard
    # index (see `save_dtensor` below), e.g. 'name.*.safetensors'
    return checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin')


def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a safetensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a safetensor checkpoint.
    """
    return checkpoint_file_path.endswith('.safetensors')
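
# Illustrative examples (comments only), using the naming convention from
# `save_dtensor` below, where '*' stands for all shards of a distributed tensor:
#   is_dtensor_checkpoint('dtensor/fc1.weight.*.safetensors')  -> True
#   is_safetensor_checkpoint('model.safetensors')              -> True
#   is_safetensor_checkpoint('pytorch_model.bin')              -> False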


# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: OrderedDict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
    exceed a given size (in MB).
    """
    current_block = {}
    current_block_size = 0

    for key, weight in state_dict.items():
        ret_block = None
        ret_block_size = 0

        if not isinstance(weight, DTensor):
            weight_size = calculate_tensor_size(weight)

            # If this weight is going to tip up over the maximal size, we split.
            if current_block_size + weight_size > max_shard_size:
                ret_block = current_block
                ret_block_size = current_block_size
                current_block = {}
                current_block_size = 0

            current_block[key] = weight
            current_block_size += weight_size

        if ret_block is not None:
            yield ret_block, ret_block_size

    yield current_block, current_block_size
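
# A minimal usage sketch (illustrative; `model` is an assumption):
#
#     for shard, shard_size in shard_checkpoint(model.state_dict(), max_shard_size=1024):
#         # each `shard` is a dict of tensors totalling at most ~1024 MB;
#         # persist it e.g. with `save_state_dict_as_shard` below
#         ...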


def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
    """
    Load shard state dict into model.
    """
    if use_safetensors and checkpoint_file.suffix != ".safetensors":
        raise Exception("load the model using `safetensors`, but the checkpoint file does not end with .safetensors")
    if use_safetensors:
        from safetensors.torch import load_file as safe_load_file
        from safetensors.torch import safe_open

        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        if metadata["format"] != "pt":
            raise NotImplementedError(
                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
        return safe_load_file(checkpoint_file)
    else:
        return torch.load(checkpoint_file)


def load_state_dict_into_model(model: nn.Module,
                               state_dict: Mapping,
                               missing_keys: List,
                               strict: bool = False,
                               load_sub_module: bool = True):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

    unexpected_keys: List[str] = []
    sub_missing_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            module._load_from_state_dict(*args)

        if load_sub_module:
            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

    load(model, state_dict, "", load_sub_module)
    del load

    # `list.append(...)` returns None; extend the caller's list in place instead
    missing_keys.extend(sub_missing_keys)

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
                    '"{}"'.format(k) for k in unexpected_keys)))
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
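
# A minimal usage sketch (illustrative; `model` and the shard path are
# assumptions, not part of this module):
#
#     missing_keys: List[str] = []
#     shard = load_shard_state_dict(Path('pytorch_model-00001-of-00002.bin'))
#     load_state_dict_into_model(model, shard, missing_keys, strict=False)
#     # after all shards are loaded, `missing_keys` holds the parameters that
#     # no shard provided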


# ======================================
# Helper functions for saving state dict
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save state dict to checkpoint.

    Args:
        state_dict (dict): state dict.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."
        assert checkpoint_file_path.endswith('.safetensors'), \
            "safetensors only supports .safetensors suffix for checkpoint file."
        from safetensors.torch import save_file as safe_save_file
        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, checkpoint_file_path)
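
# Illustrative round trip (assumes safetensors is installed):
#
#     save_state_dict(model.state_dict(), 'model.safetensors', use_safetensors=True)
#     state_dict = load_state_dict(Path('model.safetensors'))

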
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
    """
    Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
    only one tensor.

    Args:
        name (str): name of the distributed tensor.
        tensor (Tensor): tensor to be saved.
        index_file (CheckpointIndexFile): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    root_path = index_file.root_path
    output_root_path = root_path.joinpath('dtensor')

    # create directory
    output_root_path.mkdir(exist_ok=True)

    # save tensor to this directory
    # TODO(YuliangLiu): get index of the tensor shard
    # e.g. index =
    index = 0

    # save tensor to file
    ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
    ckpt_file_path = output_root_path.joinpath(ckpt_file_name)

    # dtensor ckpt file always contains only one tensor
    state_dict = {name: tensor}
    save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)

    # update the weight map
    # * means all shards
    ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
    index_file.append_weight_map(name, ckpt_file_name_in_weight_map)


def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
    """
    Get checkpoint file suffix.

    Args:
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: checkpoint file suffix.
    """
    if use_safetensors:
        return '.safetensors'
    else:
        return '.bin'


def generate_checkpoint_shard_file_name(index: int,
                                        total_number: int,
                                        use_safetensors: bool,
                                        prefix: str = None) -> str:
    """
    Generate checkpoint shard file name.

    Args:
        index (int): index of the shard.
        total_number (int): total number of shards.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
        prefix (str): prefix of the shard file name. Default: None.

    Returns:
        str: checkpoint shard file name.
    """
    # `suffix` already carries its leading dot, so no extra '.' is inserted
    suffix = get_checkpoint_file_suffix(use_safetensors)

    if prefix is None:
        return f"{index:05d}-of-{total_number:05d}{suffix}"
    else:
        return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"
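
# For example (illustrative):
#   generate_checkpoint_shard_file_name(1, 2, False, prefix='pytorch_model')
#   -> 'pytorch_model-00001-of-00002.bin'

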
def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
    """
    Generate dtensor file name.

    Args:
        param_name (str): name of the distributed parameter.
        index (int): index of the shard.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: dtensor file name.
    """
    suffix = get_checkpoint_file_suffix(use_safetensors)
    return f'{param_name}.{index}{suffix}'
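
# For example (illustrative): generate_dtensor_file_name('fc1.weight', 0, True)
# -> 'fc1.weight.0.safetensors'; the weight-map entry written by `save_dtensor`
# uses '*' in place of the index, giving 'fc1.weight.*.safetensors'.

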
def save_state_dict_as_shard(
    state_dict: dict,
    checkpoint_path: str,
    index: int,
    total_number: int,
    use_safetensors: bool,
    prefix: str = None,
) -> None:
    """
    Save state dict as shard.

    Args:
        state_dict (dict): state dict.
        checkpoint_path (str): path to the checkpoint file.
        index (int): index of the shard.
        total_number (int): total number of shards.
        prefix (str): prefix of the shard file name.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    # generate the shard name
    shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
    shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()

    # save the shard
    save_state_dict(state_dict, str(shard_file_path), use_safetensors)


# ========================================
# Helper functions for loading state dict
# ========================================
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.

    Args:
        checkpoint_path (str): path to the checkpoint.

    Returns:
        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
    """
    checkpoint_path = Path(checkpoint_path)
    if checkpoint_path.is_file():
        # check if it is .index.json
        reg = re.compile(r"(.*?)\.index((\..*)?)\.json")
        if reg.fullmatch(checkpoint_path.name) is not None:
            return True, checkpoint_path
        else:
            return False, None
    elif checkpoint_path.is_dir():
        # check if there is only one file ending with .index.json in this directory
        index_files = list(checkpoint_path.glob('*.index.*json'))

        # if we found a .index.json file, make sure there is only one
        if len(index_files) > 0:
            assert len(
                index_files
            ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'

        if len(index_files) == 1:
            return True, index_files[0]
        else:
            return False, None
    else:
        raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
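
# Illustrative behaviour (comments only; assumes the paths exist on disk):
#   has_index_file('pytorch_model.bin.index.json')  -> (True, Path('pytorch_model.bin.index.json'))
#   has_index_file('pytorch_model.bin')             -> (False, None)
#   has_index_file('ckpt_dir')  -> scans the directory for a single '*.index.*json' file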


def load_state_dict(checkpoint_file_path: Path):
    """
    Load state dict from checkpoint.

    Args:
        checkpoint_file_path (Path): path to the checkpoint file.

    Returns:
        dict: state dict.
    """
    # the checkpoint-type predicates expect plain string paths
    assert not is_dtensor_checkpoint(str(checkpoint_file_path)), \
        f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'

    if is_safetensor_checkpoint(str(checkpoint_file_path)):
        assert is_safetensors_available(), \
            f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
        # load with safetensors
        from safetensors import safe_open
        state_dict = {}
        with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        return state_dict
    else:
        # load with torch
        return torch.load(checkpoint_file_path)


def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    if variant is not None and len(variant) > 0:
        splits = weights_name.split(".")
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)

    return weights_name
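
# For example (illustrative): add_variant('pytorch_model.bin', 'fp16') inserts
# the variant before the extension, giving 'pytorch_model.fp16.bin'.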


def get_base_filenames(variant: str = None, use_safetensors: bool = False):
    """
    Generate base weight filenames.
    """
    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
    weights_name = add_variant(weights_name, variant)

    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    save_index_file = add_variant(save_index_file, variant)

    return weights_name, save_index_file


def get_shard_filename(weights_name: str, idx: int):
    """
    Get shard file name.
    """
    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
    return shard_file
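
# For example (illustrative): get_shard_filename('pytorch_model.bin', 0)
# returns 'pytorch_model-00001.bin'.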