|
|
|
@ -27,6 +27,8 @@ from .general_checkpoint_io import GeneralCheckpointIO
|
|
|
|
|
from .index_file import CheckpointIndexFile
|
|
|
|
|
from .utils import (
|
|
|
|
|
StateDictSharder,
|
|
|
|
|
async_save_state_dict_shards,
|
|
|
|
|
create_pinned_state_dict,
|
|
|
|
|
gather_distributed_param,
|
|
|
|
|
get_model_base_filenames,
|
|
|
|
|
get_optimizer_base_filenames,
|
|
|
|
@ -177,6 +179,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
prefix: Optional[str] = None,
|
|
|
|
|
size_per_shard: int = 1024,
|
|
|
|
|
use_safetensors: bool = False,
|
|
|
|
|
use_async: bool = False,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Save sharded model checkpoint under the given checkpointing path.
|
|
|
|
@ -194,6 +197,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
prefix (str, optional): Prefix of file to save. Defaults to None.
|
|
|
|
|
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
|
|
|
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
|
|
|
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
|
|
|
@ -219,24 +223,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
|
|
|
|
|
if self.pp_size == 1:
|
|
|
|
|
# When pipeline is not used, save the model shards as in general checkpointIO
|
|
|
|
|
total_size = save_state_dict_shards(
|
|
|
|
|
sharded_state_dict=state_dict_shard,
|
|
|
|
|
checkpoint=checkpoint,
|
|
|
|
|
index_file=index_file,
|
|
|
|
|
base_filename=weights_name,
|
|
|
|
|
is_master=control_saving,
|
|
|
|
|
use_safetensors=use_safetensors,
|
|
|
|
|
)
|
|
|
|
|
if control_saving:
|
|
|
|
|
index_file.append_meta_data("total_size", total_size)
|
|
|
|
|
index_file.write_index_file(save_index_file)
|
|
|
|
|
save_config_file(model, checkpoint)
|
|
|
|
|
if self.verbose and self.coordinator.is_master():
|
|
|
|
|
logging.info(
|
|
|
|
|
f"The model is split into checkpoint shards. "
|
|
|
|
|
f"You can find where each parameters has been saved in the "
|
|
|
|
|
f"index located at {save_index_file}."
|
|
|
|
|
)
|
|
|
|
|
if use_async:
|
|
|
|
|
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
|
|
|
|
|
else:
|
|
|
|
|
total_size = save_state_dict_shards(
|
|
|
|
|
sharded_state_dict=state_dict_shard,
|
|
|
|
|
checkpoint=checkpoint,
|
|
|
|
|
index_file=index_file,
|
|
|
|
|
base_filename=weights_name,
|
|
|
|
|
is_master=control_saving,
|
|
|
|
|
use_safetensors=use_safetensors,
|
|
|
|
|
)
|
|
|
|
|
if control_saving:
|
|
|
|
|
index_file.append_meta_data("total_size", total_size)
|
|
|
|
|
index_file.write_index_file(save_index_file)
|
|
|
|
|
save_config_file(model, checkpoint)
|
|
|
|
|
if self.verbose and self.coordinator.is_master():
|
|
|
|
|
logging.info(
|
|
|
|
|
f"The model is split into checkpoint shards. "
|
|
|
|
|
f"You can find where each parameters has been saved in the "
|
|
|
|
|
f"index located at {save_index_file}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# When pipeline is used, each stage produces its own shard files and index files.
|
|
|
|
@ -251,7 +258,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
|
|
|
|
|
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
|
|
|
|
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
|
|
|
|
|
|
|
|
|
if use_async:
|
|
|
|
|
# Asynchronously write this pipeline stage's shards. The number of write
# entries is taken from the class-level N_WRITE_ENTRIES setting (the same one
# passed to AsyncFileWriter in save_unsharded_model) instead of a hard-coded
# literal, so both async save paths are tuned from one place.
# NOTE(review): this replaces the literal 191; confirm N_WRITE_ENTRIES is the
# intended value for sharded pipeline saves.
total_size, returned_state_dict, writers = async_save_state_dict_shards(
    sharded_state_dict=state_dict_shard,
    checkpoint=checkpoint,
    index_file=index_file,
    base_filename=weights_name,
    is_master=control_saving,
    use_pp_format=True,
    n_write_entries=self.N_WRITE_ENTRIES,
)
|
|
|
|
|
total_size = save_state_dict_shards(
|
|
|
|
|
sharded_state_dict=state_dict_shard,
|
|
|
|
|
checkpoint=checkpoint,
|
|
|
|
@ -626,7 +642,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
if self.verbose and self.coordinator.is_master():
|
|
|
|
|
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
|
|
|
|
|
|
|
|
|
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
|
|
|
|
def save_unsharded_model(
|
|
|
|
|
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Save model state dict to a single file with given checkpointing path.
|
|
|
|
|
|
|
|
|
@ -635,6 +653,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
|
|
|
|
|
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
|
|
|
|
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
|
|
|
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
|
|
|
|
|
"""
|
|
|
|
|
if self.coordinator.is_master():
|
|
|
|
|
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
|
|
|
@ -651,7 +670,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
if self.pp_size == 1:
|
|
|
|
|
# When pipeline is not used, let master rank directly save the collected state_dict.
|
|
|
|
|
if self.tp_rank == 0:
|
|
|
|
|
save_state_dict(state_dict, checkpoint, use_safetensors)
|
|
|
|
|
if use_async:
|
|
|
|
|
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
|
|
|
|
|
else:
|
|
|
|
|
save_state_dict(state_dict, checkpoint, use_safetensors)
|
|
|
|
|
else:
|
|
|
|
|
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
|
|
|
|
state_dict_list = [None for _ in range(self.pp_size)]
|
|
|
|
@ -662,7 +684,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
complete_state_dict = dict()
|
|
|
|
|
for _state_dict in state_dict_list:
|
|
|
|
|
complete_state_dict.update(_state_dict)
|
|
|
|
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
|
|
|
|
if use_async:
    from tensornvme.async_file_io import AsyncFileWriter

    from colossalai.utils.safetensors import move_and_save

    writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
    # Pin and write the state dict merged from every pipeline stage
    # (complete_state_dict), not this stage's local state_dict; otherwise the
    # async checkpoint would miss the parameters owned by the other stages,
    # unlike the synchronous branch below.
    if id(model) not in self.pinned_state_dicts:
        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
    # Keep the writer registered on the instance alongside other async writers.
    self.async_writers.append(writer)
    move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
else:
    save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
|
|
|
|
|
|
|
|
|
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
|
|
|
|
|
"""
|
|
|
|
|