@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple
@@ -12,6 +11,26 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 _TYPES_INV = {v: k for k, v in _TYPES.items()}
+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause a 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
 @dataclass
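Aside (not part of the patch): a minimal round-trip sketch for the two pickling helpers added above, assuming only this module plus a standard torch install. _object_to_tensor packs any picklable object into a 1-D uint8 tensor and _tensor_to_object recovers it:

    obj = {"lr": 1e-3, "betas": (0.9, 0.999)}
    t = _object_to_tensor(obj, "cpu")  # uint8 tensor holding the pickle bytes
    restored = _tensor_to_object(t, t.numel() * t.element_size())
    assert restored == obj and t.dtype == torch.uint8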
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int


-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol.
-    :return: A flattened dictionary.
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol.
-    :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        # safetensors metadata values must be strings, so JSON-encode the key list
+        metadata = {"non_tensor_keys": json.dumps(non_tensor_keys)}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+
+    return state_dict


 def prepare(
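Aside (not part of the patch): a small round trip through the two new helpers, showing the flattened key layout ("state.<param_index>.<state_name>" plus an optional "param_groups" entry) and the metadata that records which values were pickled rather than stored as real tensors. The optimizer state below is made up purely for illustration:

    toy_state = {
        "state": {0: {"step": 10, "exp_avg": torch.zeros(2), "exp_avg_sq": torch.zeros(2)}},
        "param_groups": [{"lr": 1e-3, "params": [0]}],
    }
    flat, meta = _flatten_optim_state_dict(toy_state)
    # Tensor entries stay tensors; "step" (an int) and "param_groups" are pickled to byte tensors.
    assert "state.0.exp_avg" in flat and "param_groups" in flat
    assert set(json.loads(meta["non_tensor_keys"])) == {"state.0.step", "param_groups"}
    restored = _unflatten_optim_state_dict(flat, meta)
    assert restored["state"][0]["step"] == 10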
@@ -124,10 +162,8 @@ def save(
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)
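Aside (not part of the patch): save_nested no longer takes a metadata argument; the metadata is now derived internally by _flatten_optim_state_dict. A hedged call-site sketch follows; the AsyncFileWriter constructor arguments and the final synchronize/close step are assumptions about the tensornvme API and may differ across versions:

    f = open("optim_shard.safetensors", "wb")
    writer = AsyncFileWriter(f, n_entries=8, backend="pthread")  # assumed signature
    save_nested(writer, toy_state)  # toy_state from the sketch above; metadata handled internally
    # ... synchronize the async writer and close the file as the tensornvme docs describe ...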
@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
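Aside (not part of the patch): reading a checkpoint back goes through load_flat, which reverses the flattening; when the file carries a param_groups entry the result is a full optimizer state dict. A short sketch, assuming the file written in the sketch above:

    loaded = load_flat("optim_shard.safetensors")
    assert "state" in loaded and "param_groups" in loaded
    # optimizer.load_state_dict(loaded)  # hypothetical torch optimizer instance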