@ -8,9 +8,13 @@ from typing import Optional
import torch . nn as nn
import torch . nn as nn
from torch . optim import Optimizer
from torch . optim import Optimizer
from colossalai . utils . safetensors import move_and_save
from . checkpoint_io_base import CheckpointIO
from . checkpoint_io_base import CheckpointIO
from . index_file import CheckpointIndexFile
from . index_file import CheckpointIndexFile
from . utils import (
from . utils import (
async_save_state_dict_shards ,
create_pinned_state_dict ,
get_model_base_filenames ,
get_model_base_filenames ,
get_optimizer_base_filenames ,
get_optimizer_base_filenames ,
is_safetensors_available ,
is_safetensors_available ,
@ -40,15 +44,27 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint = load_state_dict ( checkpoint )
checkpoint = load_state_dict ( checkpoint )
model . load_state_dict ( checkpoint , strict = strict )
model . load_state_dict ( checkpoint , strict = strict )
def save_unsharded_model ( self , model : nn . Module , checkpoint : str , gather_dtensor : bool , use_safetensors : bool ) :
def save_unsharded_model (
self , model : nn . Module , checkpoint : str , gather_dtensor : bool , use_safetensors : bool , use_async : bool = False
) :
state_dict = model . state_dict ( )
state_dict = model . state_dict ( )
# TODO(FrankLeeeee): add support for gather_dtensor
# TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor :
if gather_dtensor :
pass
pass
# save the checkpoint
if use_async :
save_state_dict ( state_dict , checkpoint , use_safetensors )
from tensornvme . async_file_io import AsyncFileWriter
writer = AsyncFileWriter ( open ( checkpoint , " wb " ) , self . N_WRITE_ENTRIES , backend = " pthread " )
if id ( model ) not in self . pinned_state_dicts :
self . pinned_state_dicts [ id ( model ) ] = create_pinned_state_dict ( state_dict )
self . async_writers . append ( writer )
move_and_save ( writer , state_dict , self . pinned_state_dicts [ id ( model ) ] )
else :
# save the checkpoint
save_state_dict ( state_dict , checkpoint , use_safetensors )
def load_sharded_optimizer ( self , optimizer : Optimizer , index_file_path : str , prefix : str ) :
def load_sharded_optimizer ( self , optimizer : Optimizer , index_file_path : str , prefix : str ) :
"""
"""
@ -151,6 +167,7 @@ class GeneralCheckpointIO(CheckpointIO):
prefix : Optional [ str ] = None ,
prefix : Optional [ str ] = None ,
max_shard_size : int = 1024 ,
max_shard_size : int = 1024 ,
use_safetensors : bool = False ,
use_safetensors : bool = False ,
use_async : bool = False ,
) :
) :
"""
"""
implement this method as it can be supported by Huggingface model ,
implement this method as it can be supported by Huggingface model ,
@ -168,16 +185,30 @@ class GeneralCheckpointIO(CheckpointIO):
weights_name , save_index_file = get_model_base_filenames ( prefix , use_safetensors )
weights_name , save_index_file = get_model_base_filenames ( prefix , use_safetensors )
index_file = CheckpointIndexFile ( checkpoint_path )
index_file = CheckpointIndexFile ( checkpoint_path )
# Save shards of optimizer states.
if use_async :
# In general cases, is_master is set to True to get the right behavior.
pinned_state_dict = self . pinned_state_dicts . get ( id ( model ) , None )
total_size = save_state_dict_shards (
total_size , new_pinned_state_dict , writers = async_save_state_dict_shards (
sharded_state_dict = state_dict_shard ,
sharded_state_dict = state_dict_shard ,
checkpoint = checkpoint_path ,
checkpoint = checkpoint_path ,
index_file = index_file ,
index_file = index_file ,
base_filename = weights_name ,
base_filename = weights_name ,
is_master = True ,
is_master = True ,
use_safetensors = use_safetensors ,
pinned_state_dict = pinned_state_dict ,
)
n_write_entries = self . N_WRITE_ENTRIES ,
)
self . pinned_state_dicts [ id ( model ) ] = new_pinned_state_dict
self . async_writers . extend ( writers )
else :
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards (
sharded_state_dict = state_dict_shard ,
checkpoint = checkpoint_path ,
index_file = index_file ,
base_filename = weights_name ,
is_master = True ,
use_safetensors = use_safetensors ,
)
index_file . append_meta_data ( " total_size " , total_size )
index_file . append_meta_data ( " total_size " , total_size )
index_file . write_index_file ( save_index_file )
index_file . write_index_file ( save_index_file )