[ckpt] Add async ckpt api (#6136)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
pull/6124/merge
Wang Binluo 1 week ago committed by Hongxin Liu
parent d4a436051d
commit 8e08c27e19

@ -117,7 +117,7 @@ jobs:
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
DISABLE_URING=1 pip install -v --no-cache-dir .
- name: Store TensorNVMe Cache
run: |

@ -325,6 +325,7 @@ class Booster:
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
"""
self.checkpoint_io.save_model(
model,

@ -65,7 +65,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@ -74,7 +81,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
@ -112,6 +122,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save sharded model.
@ -130,27 +141,33 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
if use_async:
super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
):

@ -54,7 +54,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
@ -82,6 +84,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.

@ -176,10 +176,10 @@ class CheckpointIO(ABC):
if shard:
self.save_sharded_model(
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""

@ -61,8 +61,8 @@ class GeneralCheckpointIO(CheckpointIO):
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)

@ -27,6 +27,8 @@ from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
async_save_state_dict_shards,
create_pinned_state_dict,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
@ -177,6 +179,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
@ -194,6 +197,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
@ -219,24 +223,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
@ -251,7 +258,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
@ -626,7 +642,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model state dict to a single file with given checkpointing path.
@ -635,6 +653,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
@ -651,7 +670,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
state_dict_list = [None for _ in range(self.pp_size)]
@ -662,7 +684,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
complete_state_dict = dict()
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import move_and_save
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
"""

@ -117,6 +117,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
@ -161,24 +162,27 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
else:
@ -708,10 +712,20 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
if use_async:
super().save_unsharded_model(
model=model,
checkpoint=checkpoint,
gather_dtensor=gather_dtensor,
use_safetensors=use_safetensors,
use_async=use_async,
)
else:
torch.save(state_dict, checkpoint)
dist.barrier()
# Copied from colossalai.moe

@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
def save_state_dict(
state_dict: dict,
checkpoint_file_path: str,
use_safetensors: bool,
) -> None:
"""
Save state dict to checkpoint.
@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file, map_location=torch.device("cpu"))

@ -28,14 +28,12 @@ class PreparedData:
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = []
tensor_keys = []
metadata = {}
offset = 0
for name, tensor in sorted_data:
for name, tensor in data.items():
n = tensor.numel() * tensor.element_size()
tensor_info = TensorInfo(
dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)

@ -63,10 +63,15 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
booster.save_model(
bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
bert_model,
pretrained_path,
True,
True,
"",
(model_size / 3),
use_safetensors=use_safetensors,
)
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
@ -119,7 +124,12 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_model(
model,
model_ckpt_path,
shard=shard,
size_per_shard=size_per_shard,
)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()

@ -26,9 +26,10 @@ from tests.kit.model_zoo import model_zoo
# only test 2 is fine
@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("offload", [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
@parameterize("use_async", [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
booster = Booster(plugin=plugin)
model = resnet18()
@ -41,13 +42,26 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
if not shard and not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(
model,
model_ckpt_path,
shard=shard,
use_async=use_async,
)
# lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
booster.save_model(model, model_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_model = resnet18()
@ -71,6 +85,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
torch.cuda.empty_cache()

Loading…
Cancel
Save