From 8e08c27e19d3f8dcfbae36dffcad0591c0cf9cfc Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Fri, 15 Nov 2024 18:19:16 +0800 Subject: [PATCH] [ckpt] Add async ckpt api (#6136) * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- .github/workflows/build_on_pr.yml | 2 +- colossalai/booster/booster.py | 1 + colossalai/booster/plugin/gemini_plugin.py | 57 +++++++++----- .../booster/plugin/torch_fsdp_plugin.py | 5 +- .../checkpoint_io/checkpoint_io_base.py | 4 +- .../checkpoint_io/general_checkpoint_io.py | 2 +- .../hybrid_parallel_checkpoint_io.py | 77 +++++++++++++------ colossalai/checkpoint_io/moe_checkpoint.py | 52 ++++++++----- colossalai/checkpoint_io/utils.py | 13 ++-- colossalai/utils/safetensors.py | 4 +- .../test_gemini_checkpoint_io.py | 16 +++- .../test_low_level_zero_checkpoint_io.py | 23 +++++- 12 files changed, 172 insertions(+), 84 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index ceb33c9ac..8d96ca1b9 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -117,7 +117,7 @@ jobs: cd TensorNVMe conda install cmake pip install -r requirements.txt - DISABLE_URING=1 pip install -v . + DISABLE_URING=1 pip install -v --no-cache-dir . - name: Store TensorNVMe Cache run: | diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 2518b2511..ad4047ee2 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -325,6 +325,7 @@ class Booster: names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. + use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False. """ self.checkpoint_io.save_model( model, diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4c8258113..35c51da01 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -65,7 +65,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO): self.coordinator = DistCoordinator() self.logger = get_dist_logger() - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, + model: GeminiDDP, + checkpoint: str, + gather_dtensor: bool, + use_safetensors: bool, + use_async: bool = False, + ): """ Save sharded model to checkpoint but only on master process. The model should be unwrapped in self.load_model via ModelWrapper.unwrap. @@ -74,7 +81,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO): assert isinstance(model, GeminiDDP), "Please boost the model before saving!" state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + else: + save_state_dict(state_dict, checkpoint, use_safetensors) def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): """ @@ -112,6 +122,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): """ Save sharded model. @@ -130,27 +141,33 @@ class GeminiCheckpointIO(GeneralCheckpointIO): # Save shards of optimizer states. is_master = self.coordinator.is_master() - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=is_master, - use_safetensors=use_safetensors, - ) + if use_async: + super().save_sharded_model( + model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async + ) - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.", - ranks=[0], + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors, ) + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.", + ranks=[0], + ) + def load_sharded_model( self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False ): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 23a35bbcb..d309370dd 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -54,7 +54,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): """ Save model to checkpoint but only on master process. """ @@ -82,6 +84,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): """ Save model to checkpoint but only on master process. diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 4d752f3e6..6e4681f0e 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -176,10 +176,10 @@ class CheckpointIO(ABC): if shard: self.save_sharded_model( - model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async + model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async ) else: - self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): """ diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a4866e64c..580be91ca 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -61,8 +61,8 @@ class GeneralCheckpointIO(CheckpointIO): self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) self.async_writers.append(writer) move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) - else: + else: # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 79bb33dca..49d4f35f9 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -27,6 +27,8 @@ from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( StateDictSharder, + async_save_state_dict_shards, + create_pinned_state_dict, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, @@ -177,6 +179,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ) -> None: """ Save sharded model checkpoint under the given checkpointing path. @@ -194,6 +197,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): prefix (str, optional): Perfix of file to save. Defaults to None. size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. """ assert isinstance(model, ModelWrapper), "Please boost the model before saving!" @@ -219,24 +223,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if self.pp_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -251,7 +258,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - + if use_async: + total_size, returned_state_dict, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_pp_format=True, + n_write_entries=191, + ) total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint, @@ -626,7 +642,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): """ Save model state dict to a single file with given checkpointing path. @@ -635,6 +653,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. """ if self.coordinator.is_master(): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") @@ -651,7 +670,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if self.pp_size == 1: # When pipeline is not used, let master rank directly save the collected state_dict. if self.tp_rank == 0: - save_state_dict(state_dict, checkpoint, use_safetensors) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + else: + save_state_dict(state_dict, checkpoint, use_safetensors) else: # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] @@ -662,7 +684,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): complete_state_dict = dict() for _state_dict in state_dict_list: complete_state_dict.update(_state_dict) - save_state_dict(complete_state_dict, checkpoint, use_safetensors) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + from colossalai.utils.safetensors import move_and_save + + writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + self.async_writers.append(writer) + move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) + else: + save_state_dict(complete_state_dict, checkpoint, use_safetensors) def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): """ diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 9181956b7..4cb0f300f 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -117,6 +117,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO): prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ) -> None: """ Save sharded model checkpoint under the given checkpointing path. @@ -161,24 +162,27 @@ class MoECheckpointIO(HybridParallelCheckpointIO): if self.pp_size == 1 and self.ep_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) dist.barrier() else: @@ -708,10 +712,20 @@ class MoECheckpointIO(HybridParallelCheckpointIO): checkpoint: str, gather_dtensor: bool, use_safetensors: bool, + use_async: bool = False, ): state_dict = self.pre_save_model(model) if dist.get_rank() == 0: - torch.save(state_dict, checkpoint) + if use_async: + super().save_unsharded_model( + model=model, + checkpoint=checkpoint, + gather_dtensor=gather_dtensor, + use_safetensors=use_safetensors, + use_async=use_async, + ) + else: + torch.save(state_dict, checkpoint) dist.barrier() # Copied from colossalai.moe diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6d539cce6..8487064f5 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # ====================================== -def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: +def save_state_dict( + state_dict: dict, + checkpoint_file_path: str, + use_safetensors: bool, +) -> None: """ Save state dict to checkpoint. @@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): raise Exception("load the model using `safetensors`, but no file endwith .safetensors") if use_safetensors: from safetensors.torch import load_file as safe_load_file - from safetensors.torch import safe_open - with safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - if metadata["format"] != "pt": - raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." - ) return safe_load_file(checkpoint_file) else: return torch.load(checkpoint_file, map_location=torch.device("cpu")) diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index bf8decd0f..035954114 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -28,14 +28,12 @@ class PreparedData: def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]: - sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0])) tensors = [] tensor_keys = [] metadata = {} offset = 0 - - for name, tensor in sorted_data: + for name, tensor in data.items(): n = tensor.numel() * tensor.element_size() tensor_info = TensorInfo( dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index b133be948..8bee8fe97 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -63,10 +63,15 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 booster.save_model( - bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors + bert_model, + pretrained_path, + True, + True, + "", + (model_size / 3), + use_safetensors=use_safetensors, ) dist.barrier() - new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict()) @@ -119,7 +124,12 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_model( + model, + model_ckpt_path, + shard=shard, + size_per_shard=size_per_shard, + ) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index a8e05a25a..5e3cc2bdc 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -26,9 +26,10 @@ from tests.kit.model_zoo import model_zoo # only test 2 is fine @clear_cache_before_run() @parameterize("stage", [2]) -@parameterize("shard", [True, False]) +@parameterize("shard", [False, True]) @parameterize("offload", [False, True]) -def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): +@parameterize("use_async", [False, True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) model = resnet18() @@ -41,13 +42,26 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): loss = criterion(output) booster.backward(loss, optimizer) optimizer.step() + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" + if not shard and not use_async: + model_ckpt_path = f"{model_ckpt_path}.pt" + if not shard and use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + booster.save_model( + model, + model_ckpt_path, + shard=shard, + use_async=use_async, + ) + # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here - booster.save_model(model, model_ckpt_path, shard=shard) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) - + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() dist.barrier() new_model = resnet18() @@ -71,6 +85,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) + torch.cuda.empty_cache()