@@ -5,7 +5,7 @@ from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple

import torch
import torch.nn as nn
@@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
    to_global,
    to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import move_and_save
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -263,6 +264,71 @@ def save_state_dict_shards(
    return total_size

def async_save_state_dict_shards(
    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
    checkpoint: str,
    index_file: "CheckpointIndexFile",
    base_filename: str,
    is_master: bool,
    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
    n_write_entries: int,
    use_pp_format: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Save sharded state dict only on master rank , this method can be used by both model and optimizer states .
Args :
sharded_state_dict ( Iterator [ Tuple [ OrderedDict , int ] ] ) : a generator of shards , each shard contains state dict and shard size .
checkpoint ( str ) : The path of checkpoint directory as string .
index_file ( CheckpointIndexFile ) : The index file object to be updated .
base_filename ( str ) : Decides the prefix of filenames of shards .
is_master ( bool ) : Whether current rank is main process .
use_safetensors ( bool , optional ) : Whether to use safetensors to save checkpoint . Defaults to False .
use_pp_format : ( bool , optional ) : Whether to save the files in pipeline format including stage information . Defaults to False .
Returns :
int : the total size of shards
"""
    from tensornvme.async_file_io import AsyncFileWriter

    total_size = 0
    shard_filenames = []
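    # Reuse the caller's pinned buffers as the returned state dict when provided; otherwise build one shard by shard.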
    if pinned_state_dict is None:
        returned_state_dict = {}
    else:
        returned_state_dict = pinned_state_dict
    writers = []
    for idx, shard_pair in enumerate(sharded_state_dict):
        shard, current_size = shard_pair
        # Just loop over the sharder and gather to other ranks if not master
        if not is_master:
            del shard
            continue

        shard_file = get_shard_filename(base_filename, idx)
        total_size = total_size + current_size
        for key in shard.keys():
            index_file.append_weight_map(key, shard_file)
        checkpoint_file_path = os.path.join(checkpoint, shard_file)
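        # One AsyncFileWriter per shard: the file is written on a background pthread, and the open
        # writers are returned so the caller can synchronize and close them once the writes complete.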
        writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
        writers.append(writer)
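        # Reuse the caller-provided pinned CPU buffers for this shard's keys when available;
        # otherwise allocate fresh pinned buffers so the shard can be staged for the async write.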
        if pinned_state_dict is not None:
            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
        else:
            sub_pinned_state_dict = create_pinned_state_dict(shard)
        returned_state_dict.update(sub_pinned_state_dict)
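        # Copy the shard into the pinned buffers and queue the asynchronous safetensors write.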
        # Only save on master rank.
        move_and_save(writer, shard, sub_pinned_state_dict)
        shard_filenames.append(shard_file)
        del shard

    # Clean folder, delete unneeded files.
    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)

    return total_size, returned_state_dict, writers

def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int):
    shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}.bin")
    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
    return shard_file

def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
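    """Allocate CPU pinned-memory buffers matching the shapes and dtypes of the given state dict."""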
    pin_mem = dict()
    for name, tensor in state_dict.items():
        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
    return pin_mem