# A Python safetensors serializer, modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
import warnings
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple
import torch
from safetensors.torch import _TYPES, load_file, safe_open
try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()}


@dataclass
class TensorInfo:
    dtype: str
    shape: List[int]
    data_offsets: Tuple[int, int]


@dataclass
class PreparedData:
    n: int
    header_bytes: bytes
    offset: int
def flatten_dict(nested_dict, parent_key="", separator="^"):
    """
    Flatten a nested dictionary, producing a flat dictionary whose keys are joined by the given separator.

    nested_dict: The input nested dictionary.
    parent_key: The parent key currently being processed.
    separator: The separator used to join keys; defaults to "^" but can be customized.
    :return: A flattened dictionary.
    """
    items = []
    for k, v in nested_dict.items():
        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, separator).items())
        else:
            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
            items.append((new_key, v))
    return dict(items)
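# Illustrative sketch (not executed in this module): flatten_dict turns a nested,
# optimizer-style dict into a flat mapping of tensors keyed by "^"-joined paths;
# non-tensor leaves are converted to float16 tensors. The input dict is hypothetical.
#
#     flat = flatten_dict({"state": {0: {"step": 1, "exp_avg": torch.zeros(2)}}})
#     # flat == {"state^0^step": tensor(1., dtype=torch.float16),
#     #          "state^0^exp_avg": tensor([0., 0.])}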
def unflatten_dict(flattened_dict, separator="^"):
    """
    Restore a flattened dictionary back to a multi-level nested dictionary.

    flattened_dict: The flattened dictionary.
    separator: The separator used during flattening; defaults to "^" but can be customized.
    :return: The restored nested dictionary.
    """
    nested_dict = {}
    for key, value in flattened_dict.items():
        keys = key.split(separator)
        try:
            keys[0] = int(keys[0])
        except ValueError:
            warnings.warn(f"{keys[0]} can't be converted to an integer")
        d = nested_dict
        for part in keys[:-1]:
            if part not in d:
                d[part] = {}
            d = d[part]
        assert isinstance(value, torch.Tensor)
        d[keys[-1]] = value
    return nested_dict
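# Illustrative round trip (not executed here): keys are split on the separator and only the
# first path component is converted back to an int, so a dict keyed by parameter indices
# survives a flatten/unflatten cycle. Note that non-tensor leaves come back as float16
# tensors because of the conversion in flatten_dict. The input dict is hypothetical.
#
#     nested = {0: {"exp_avg": torch.zeros(2), "step": 1}}
#     restored = unflatten_dict(flatten_dict(nested))
#     # restored == {0: {"exp_avg": tensor([0., 0.]),
#     #                  "step": tensor(1., dtype=torch.float16)}}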
def prepare(
    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
    if metadata is not None:
        assert isinstance(metadata, dict)
        for k, v in metadata.items():
            metadata[k] = json.dumps(v)
            assert isinstance(k, str)
            assert isinstance(metadata[k], str)

    tensors = []
    tensor_keys = []
    header = {}
    offset = 0

    if metadata is not None:
        header["__metadata__"] = metadata

    for name, tensor in data.items():
        n = tensor.numel() * tensor.element_size()
        tensor_info = TensorInfo(
            dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
        )
        offset += n
        header[name] = asdict(tensor_info)
        tensors.append(tensor)
        tensor_keys.append(name)

    header_buf = json.dumps(header).encode("utf-8")

    extra = (8 - len(header_buf) % 8) % 8
    header_buf += b" " * extra

    n = len(header_buf)

    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
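# The layout that prepare()/save() target is the standard safetensors container: an 8-byte
# little-endian header length, a JSON header padded with spaces to an 8-byte boundary, then
# the raw tensor bytes at the offsets recorded in the header. A minimal sketch of reading
# the header back with only the standard library (the file name is hypothetical):
#
#     with open("model.safetensors", "rb") as fp:
#         header_len = int.from_bytes(fp.read(8), byteorder="little")
#         header = json.loads(fp.read(header_len))
#     # header["__metadata__"] holds the user metadata; every other entry maps a tensor name
#     # to its dtype, shape, and data_offsets relative to the end of the header.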
def save(
    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> None:
    prepared_data, tensors, _ = prepare(state_dict, metadata)
    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
    f_writer.write(n.to_bytes(8, byteorder="little"))
    f_writer.write(header_bytes)
    for tensor in tensors:
        f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
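# The file written by save() is a regular safetensors file, so it can be read back with the
# stock safetensors API already imported above; a small sketch (the path is hypothetical):
#
#     tensors = load_file("model.safetensors")
#     with safe_open("model.safetensors", framework="pt") as f:
#         metadata = f.metadata()  # values are the JSON strings produced by prepare()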
def save_nested(
    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> None:
    flatten_data = flatten_dict(state_dict)
    save(f_writer, flatten_data, metadata)
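# save_nested() pairs with load_flat() at the bottom of this file: the nested "state" part
# of an optimizer state dict is flattened before writing, while non-tensor pieces such as
# param_groups travel through the metadata argument. A hedged sketch, assuming f_writer is
# an already-constructed AsyncFileWriter and optimizer is a torch optimizer:
#
#     osd = optimizer.state_dict()
#     save_nested(f_writer, osd["state"], metadata={"param_groups": osd["param_groups"]})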
def move_and_save(
    f_writer: AsyncFileWriter,
    state_dict: Dict[str, torch.Tensor],
    state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
    prepared_data, _, tensor_keys = prepare(state_dict)
    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
    f_writer.write(n.to_bytes(8, byteorder="little"))
    f_writer.write(header_bytes)
    f_writer.register_h2d(len(tensor_keys))
    for name in tensor_keys:
        if state_dict_pinned:
            f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
        else:
            f_writer.write_tensor(state_dict[name])
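# The optional state_dict_pinned argument lets move_and_save() reuse pre-allocated pinned
# host buffers (presumably as staging space for the device-to-host copy) instead of
# allocating per call. A hedged sketch of building such buffers with plain torch, mirroring
# the shapes and dtypes of the source tensors:
#
#     state_dict_pinned = {
#         k: torch.empty_like(v, device="cpu").pin_memory() for k, v in state_dict.items()
#     }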
def load_flat(checkpoint_path):
    with safe_open(checkpoint_path, framework="pt") as f:
        metadata = f.metadata()
    state_dict_load = load_file(checkpoint_path)
    state_dict = unflatten_dict(state_dict_load)
    if metadata is None:
        return state_dict
    metadata = {k: json.loads(v) for k, v in metadata.items()}
    combined_state_dict = {"state": state_dict}
    combined_state_dict.update(metadata)
    return combined_state_dict
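# Example usage (the path is hypothetical): for a checkpoint written with save_nested() as
# in the sketch above, load_flat() returns the nested tensor state plus the decoded metadata:
#
#     osd = load_flat("optimizer.safetensors")
#     # osd == {"state": {0: {...}, 1: {...}}, "param_groups": [...]}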