# A Python safetensors serializer, modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import io
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple

import torch
from safetensors.torch import _TYPES, load_file, safe_open
from torch.distributed.distributed_c10d import _pickler, _unpickler

try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

_TYPES_INV = {v: k for k, v in _TYPES.items()}
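# Queue depth passed to AsyncFileWriter below; the value is kept from the
# original source and is presumably tuned for the pthread backend.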
ASYNC_WRITE_ENTRIES = 32
def _object_to_tensor(obj, device):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
    # Otherwise, it will cause 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage).to(device)
    return byte_tensor


def _tensor_to_object(tensor, tensor_size):
    tensor = tensor.cpu()
    buf = tensor.numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()
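# Illustrative round trip (not part of the original code): any picklable
# object can be staged as a byte tensor and recovered, e.g.
#   t = _object_to_tensor([1, 2, 3], "cpu")
#   assert _tensor_to_object(t, t.numel() * t.element_size()) == [1, 2, 3]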
@dataclass
class TensorInfo:
    dtype: str
    shape: List[int]
    data_offsets: Tuple[int, int]


@dataclass
class PreparedData:
    n: int
    header_bytes: bytes
    offset: int
def _cast_to_tensor(obj):
    if isinstance(obj, torch.Tensor):
        return obj
    return _object_to_tensor(obj, "cpu")


def _cast_to_object(tensor: torch.Tensor):
    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
    flat_dict = {}
    non_tensor_keys = []
    if "state" in state_dict:
        # 3-level dict
        states = state_dict["state"]
    else:
        # 2-level dict, usually for optimizer state dict shard
        states = state_dict

    for idx, d in states.items():
        for k, v in d.items():
            if v is None:
                continue
            nested_key = f"state{seperator}{idx}{seperator}{k}"
            if not isinstance(v, torch.Tensor):
                non_tensor_keys.append(nested_key)
            flat_dict[nested_key] = _cast_to_tensor(v)
    if "param_groups" in state_dict:
        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
        non_tensor_keys.append("param_groups")
    if len(non_tensor_keys) > 0:
        metadata = {"non_tensor_keys": non_tensor_keys}
    else:
        metadata = None
    return flat_dict, metadata
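# Key layout sketch (hypothetical values): an Adam-style state dict such as
#   {"state": {0: {"step": 10, "exp_avg": t}}, "param_groups": [...]}
# flattens to
#   {"state.0.step": <pickled byte tensor>, "state.0.exp_avg": t,
#    "param_groups": <pickled byte tensor>}
# with metadata {"non_tensor_keys": ["state.0.step", "param_groups"]}.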
def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
    state_dict = {}

    if metadata is not None and "non_tensor_keys" in metadata:
        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
    else:
        non_tensor_keys = []
    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
    if "param_groups" in flat_dict:
        # 3-level dict
        state_dict["param_groups"] = flat_dict.pop("param_groups")
        state_dict["state"] = {}
        states = state_dict["state"]
    else:
        # 2-level dict, usually for optimizer state dict shard
        states = state_dict

    for k, v in flat_dict.items():
        parts = k.split(seperator)
        assert len(parts) == 3 and parts[0] == "state"
        idx = int(parts[1])
        key = parts[2]
        if idx not in states:
            states[idx] = {}
        states[idx][key] = v
    return state_dict
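# Note: _unflatten_optim_state_dict expects metadata["non_tensor_keys"] to be a
# JSON string. prepare() json.dumps() every metadata value before writing, so
# the plain list produced by _flatten_optim_state_dict only becomes decodable
# after a save()/load_flat() round trip.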
def prepare(
    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
    if metadata is not None:
        assert isinstance(metadata, dict)
        for k, v in metadata.items():
            metadata[k] = json.dumps(v)
            assert isinstance(k, str)
            assert isinstance(metadata[k], str)

    tensors = []
    tensor_keys = []
    header = {}
    offset = 0

    header_metadata = {"format": "pt"}
    if metadata is not None:
        header_metadata.update(metadata)
    header["__metadata__"] = header_metadata

    for name, tensor in data.items():
        n = tensor.numel() * tensor.element_size()
        tensor_info = TensorInfo(
            dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
        )
        offset += n
        header[name] = asdict(tensor_info)
        tensors.append(tensor)
        tensor_keys.append(name)

    header_buf = json.dumps(header).encode("utf-8")

    extra = (8 - len(header_buf) % 8) % 8
    header_buf += b" " * extra

    n = len(header_buf)

    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
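# Resulting file layout (standard safetensors): an 8-byte little-endian header
# size, the JSON header padded with spaces to a multiple of 8 bytes, then the
# raw tensor bytes at the recorded data_offsets.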
def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> AsyncFileWriter:
    prepared_data, tensors, _ = prepare(state_dict, metadata)
    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
    f_writer.write(n.to_bytes(8, byteorder="little"))
    f_writer.write(header_bytes)
    for tensor in tensors:
        f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
    return f_writer
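# Usage sketch (illustrative): the write is asynchronous, so keep the returned
# writer alive and flush it before reading the file back, e.g.
#   writer = save("model.safetensors", model.state_dict())
#   writer.synchronize()  # assumption: tensornvme's AsyncFileWriter exposes synchronize()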
def save_nested(path: str, state_dict: dict) -> AsyncFileWriter:
    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
    return save(path, flatten_data, metadata)
def move_and_save(
    path: str,
    state_dict: Dict[str, torch.Tensor],
    state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> AsyncFileWriter:
    prepared_data, _, tensor_keys = prepare(state_dict, metadata)
    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
    f_writer.write(n.to_bytes(8, byteorder="little"))
    f_writer.write(header_bytes)
    f_writer.register_h2d(len(tensor_keys))
    for name in tensor_keys:
        if state_dict_pinned:
            f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
        else:
            f_writer.write_tensor(state_dict[name])
    return f_writer
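# Usage sketch (illustrative): for GPU state dicts, pre-allocated pinned host
# buffers avoid re-allocating staging memory on every checkpoint, e.g.
#   pinned = {k: torch.empty(v.shape, dtype=v.dtype, device="cpu", pin_memory=True)
#             for k, v in sd.items()}
#   writer = move_and_save("model.safetensors", sd, state_dict_pinned=pinned)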
def load_flat(checkpoint_path, seperator: str = "."):
    with safe_open(checkpoint_path, framework="pt") as f:
        metadata = f.metadata()
    state_dict_load = load_file(checkpoint_path)
    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator)
    return state_dict
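# Round-trip sketch (illustrative): save_nested()/load_flat() move an optimizer
# state dict through a single safetensors file, e.g.
#   writer = save_nested("optim.safetensors", optimizer.state_dict())
#   del writer  # assumption: dropping the writer flushes pending async writes
#   state = load_flat("optim.safetensors")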