|
|
|
@ -27,6 +27,8 @@ from .general_checkpoint_io import GeneralCheckpointIO
|
|
|
|
|
from .index_file import CheckpointIndexFile |
|
|
|
|
from .utils import ( |
|
|
|
|
StateDictSharder, |
|
|
|
|
async_save_state_dict_shards, |
|
|
|
|
create_pinned_state_dict, |
|
|
|
|
gather_distributed_param, |
|
|
|
|
get_model_base_filenames, |
|
|
|
|
get_optimizer_base_filenames, |
|
|
|
@ -177,6 +179,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
prefix: Optional[str] = None, |
|
|
|
|
size_per_shard: int = 1024, |
|
|
|
|
use_safetensors: bool = False, |
|
|
|
|
use_async: bool = False, |
|
|
|
|
) -> None: |
|
|
|
|
""" |
|
|
|
|
Save sharded model checkpoint under the given checkpointing path. |
|
|
|
@ -194,6 +197,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
prefix (str, optional): Perfix of file to save. Defaults to None. |
|
|
|
|
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. |
|
|
|
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. |
|
|
|
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" |
|
|
|
@ -219,24 +223,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
|
|
|
|
|
if self.pp_size == 1: |
|
|
|
|
# When pipeline is not used, save the model shards as in general checkpointIO |
|
|
|
|
total_size = save_state_dict_shards( |
|
|
|
|
sharded_state_dict=state_dict_shard, |
|
|
|
|
checkpoint=checkpoint, |
|
|
|
|
index_file=index_file, |
|
|
|
|
base_filename=weights_name, |
|
|
|
|
is_master=control_saving, |
|
|
|
|
use_safetensors=use_safetensors, |
|
|
|
|
) |
|
|
|
|
if control_saving: |
|
|
|
|
index_file.append_meta_data("total_size", total_size) |
|
|
|
|
index_file.write_index_file(save_index_file) |
|
|
|
|
save_config_file(model, checkpoint) |
|
|
|
|
if self.verbose and self.coordinator.is_master(): |
|
|
|
|
logging.info( |
|
|
|
|
f"The model is split into checkpoint shards. " |
|
|
|
|
f"You can find where each parameters has been saved in the " |
|
|
|
|
f"index located at {save_index_file}." |
|
|
|
|
) |
|
|
|
|
if use_async: |
|
|
|
|
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) |
|
|
|
|
else: |
|
|
|
|
total_size = save_state_dict_shards( |
|
|
|
|
sharded_state_dict=state_dict_shard, |
|
|
|
|
checkpoint=checkpoint, |
|
|
|
|
index_file=index_file, |
|
|
|
|
base_filename=weights_name, |
|
|
|
|
is_master=control_saving, |
|
|
|
|
use_safetensors=use_safetensors, |
|
|
|
|
) |
|
|
|
|
if control_saving: |
|
|
|
|
index_file.append_meta_data("total_size", total_size) |
|
|
|
|
index_file.write_index_file(save_index_file) |
|
|
|
|
save_config_file(model, checkpoint) |
|
|
|
|
if self.verbose and self.coordinator.is_master(): |
|
|
|
|
logging.info( |
|
|
|
|
f"The model is split into checkpoint shards. " |
|
|
|
|
f"You can find where each parameters has been saved in the " |
|
|
|
|
f"index located at {save_index_file}." |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
# When pipeline is used, each stage produces its own shard files and index files. |
|
|
|
@ -251,7 +258,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") |
|
|
|
|
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") |
|
|
|
|
save_index_file = os.path.join("tmp_index_files", save_index_file) |
|
|
|
|
|
|
|
|
|
if use_async: |
|
|
|
|
total_size, returned_state_dict, writers = async_save_state_dict_shards( |
|
|
|
|
sharded_state_dict=state_dict_shard, |
|
|
|
|
checkpoint=checkpoint, |
|
|
|
|
index_file=index_file, |
|
|
|
|
base_filename=weights_name, |
|
|
|
|
is_master=control_saving, |
|
|
|
|
use_pp_format=True, |
|
|
|
|
n_write_entries=191, |
|
|
|
|
) |
|
|
|
|
total_size = save_state_dict_shards( |
|
|
|
|
sharded_state_dict=state_dict_shard, |
|
|
|
|
checkpoint=checkpoint, |
|
|
|
@ -626,7 +642,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
if self.verbose and self.coordinator.is_master(): |
|
|
|
|
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") |
|
|
|
|
|
|
|
|
|
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): |
|
|
|
|
def save_unsharded_model( |
|
|
|
|
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Save model state dict to a single file with given checkpointing path. |
|
|
|
|
|
|
|
|
@ -635,6 +653,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. |
|
|
|
|
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. |
|
|
|
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. |
|
|
|
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. |
|
|
|
|
""" |
|
|
|
|
if self.coordinator.is_master(): |
|
|
|
|
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") |
|
|
|
@ -651,7 +670,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
if self.pp_size == 1: |
|
|
|
|
# When pipeline is not used, let master rank directly save the collected state_dict. |
|
|
|
|
if self.tp_rank == 0: |
|
|
|
|
save_state_dict(state_dict, checkpoint, use_safetensors) |
|
|
|
|
if use_async: |
|
|
|
|
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) |
|
|
|
|
else: |
|
|
|
|
save_state_dict(state_dict, checkpoint, use_safetensors) |
|
|
|
|
else: |
|
|
|
|
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. |
|
|
|
|
state_dict_list = [None for _ in range(self.pp_size)] |
|
|
|
@ -662,7 +684,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|
|
|
|
complete_state_dict = dict() |
|
|
|
|
for _state_dict in state_dict_list: |
|
|
|
|
complete_state_dict.update(_state_dict) |
|
|
|
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors) |
|
|
|
|
if use_async: |
|
|
|
|
from tensornvme.async_file_io import AsyncFileWriter |
|
|
|
|
|
|
|
|
|
from colossalai.utils.safetensors import move_and_save |
|
|
|
|
|
|
|
|
|
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") |
|
|
|
|
if id(model) not in self.pinned_state_dicts: |
|
|
|
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) |
|
|
|
|
self.async_writers.append(writer) |
|
|
|
|
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) |
|
|
|
|
else: |
|
|
|
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors) |
|
|
|
|
|
|
|
|
|
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): |
|
|
|
|
""" |
|
|
|
|