
[ckpt] Add async ckpt api (#6136)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
Branch: pull/6124/merge
Wang Binluo authored 7 days ago, committed by Hongxin Liu
commit 8e08c27e19
Changed files (number of changed lines in parentheses):

  1. .github/workflows/build_on_pr.yml (2)
  2. colossalai/booster/booster.py (1)
  3. colossalai/booster/plugin/gemini_plugin.py (19)
  4. colossalai/booster/plugin/torch_fsdp_plugin.py (5)
  5. colossalai/checkpoint_io/checkpoint_io_base.py (4)
  6. colossalai/checkpoint_io/general_checkpoint_io.py (2)
  7. colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (37)
  8. colossalai/checkpoint_io/moe_checkpoint.py (14)
  9. colossalai/checkpoint_io/utils.py (13)
  10. colossalai/utils/safetensors.py (4)
  11. tests/test_checkpoint_io/test_gemini_checkpoint_io.py (16)
  12. tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py (23)

.github/workflows/build_on_pr.yml (2 changed lines)

@@ -117,7 +117,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          DISABLE_URING=1 pip install -v .
+          DISABLE_URING=1 pip install -v --no-cache-dir .
       - name: Store TensorNVMe Cache
         run: |

colossalai/booster/booster.py (1 changed line)

@@ -325,6 +325,7 @@ class Booster:
                 names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
             use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
+            use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
         """
         self.checkpoint_io.save_model(
             model,
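
Note: the flag is simply forwarded to the underlying CheckpointIO.save_model. A hedged usage sketch follows (paths are illustrative; the two _sync_* calls are the ones the updated low-level-zero test further down uses to flush pending writes before reading the checkpoint back):

    # assumes a booster and a boosted model already exist
    booster.save_model(model, "ckpt/model.safetensors", shard=False, use_async=True)

    # training can continue here while device-to-host copies and file writes overlap

    booster.checkpoint_io._sync_d2h()  # wait for tensors to land in the pinned CPU buffers
    booster.checkpoint_io._sync_io()   # wait for the background file writers to finish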

colossalai/booster/plugin/gemini_plugin.py (19 changed lines)

@@ -65,7 +65,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()

-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self,
+        model: GeminiDDP,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+        use_async: bool = False,
+    ):
         """
         Save sharded model to checkpoint but only on master process.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.

@@ -74,6 +81,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors)
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors)

     def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):

@@ -112,6 +122,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save sharded model.

@@ -130,6 +141,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
+        if use_async:
+            super().save_sharded_model(
+                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            )
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
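
Note: both Gemini hunks follow the same shape: gather the state dict, and on the master rank either delegate to the GeneralCheckpointIO async path or fall back to the synchronous save. A minimal, self-contained sketch of that dispatch (BaseIO and GeminiLikeIO are stand-ins for illustration, not ColossalAI API):

    # Illustrative dispatch only; the real logic lives in GeminiCheckpointIO / GeneralCheckpointIO.
    class BaseIO:
        def save_unsharded_model(self, model, checkpoint, gather_dtensor, use_safetensors, use_async=False):
            # placeholder for the pinned-buffer async path vs. the plain synchronous save
            print("async save" if use_async else "sync save")

    class GeminiLikeIO(BaseIO):
        def __init__(self, is_master: bool):
            self.is_master = is_master

        def save_unsharded_model(self, model, checkpoint, gather_dtensor, use_safetensors, use_async=False):
            if self.is_master:  # only the master rank holds the gathered tensors and writes the file
                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)

    GeminiLikeIO(is_master=True).save_unsharded_model(None, "model.safetensors", True, True, use_async=True)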

colossalai/booster/plugin/torch_fsdp_plugin.py (5 changed lines)

@@ -54,7 +54,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """

@@ -82,6 +84,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.

colossalai/checkpoint_io/checkpoint_io_base.py (4 changed lines)

@@ -176,10 +176,10 @@ class CheckpointIO(ABC):
         if shard:
             self.save_sharded_model(
-                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
+                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
             )
         else:
-            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """

colossalai/checkpoint_io/general_checkpoint_io.py (2 changed lines)

@@ -61,8 +61,8 @@ class GeneralCheckpointIO(CheckpointIO):
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)
             move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
-
         else:
             # save the checkpoint
             save_state_dict(state_dict, checkpoint, use_safetensors)
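
Note: the context lines above are the async branch GeneralCheckpointIO already takes when use_async=True, and the same pattern reappears verbatim in the hybrid-parallel hunk below. A standalone sketch of that flow is given here; the helper names and call signatures are copied from this diff, while the function wrapper, the key argument and the n_write_entries default (191, the value used elsewhere in this diff) are simplifications for illustration:

    from tensornvme.async_file_io import AsyncFileWriter

    from colossalai.checkpoint_io.utils import create_pinned_state_dict
    from colossalai.utils.safetensors import move_and_save

    def async_save_unsharded(state_dict, checkpoint_path, pinned_state_dicts, async_writers, key, n_write_entries=191):
        # background writer over a plain file handle, "pthread" backend as in the diff
        writer = AsyncFileWriter(open(checkpoint_path, "wb"), n_write_entries, backend="pthread")
        if key not in pinned_state_dicts:
            # pinned (page-locked) CPU buffers let the device-to-host copy proceed without blocking training
            pinned_state_dicts[key] = create_pinned_state_dict(state_dict)
        async_writers.append(writer)  # tracked so _sync_d2h()/_sync_io() can flush it later
        # copy tensors into the pinned buffers and queue the safetensors writes
        move_and_save(writer, state_dict, pinned_state_dicts[key])
        return writer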

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (37 changed lines)

@@ -27,6 +27,8 @@ from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
     StateDictSharder,
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     gather_distributed_param,
     get_model_base_filenames,
     get_optimizer_base_filenames,

@@ -177,6 +179,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """
         Save sharded model checkpoint under the given checkpointing path.

@@ -194,6 +197,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             prefix (str, optional): Perfix of file to save. Defaults to None.
             size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+            use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"

@@ -219,6 +223,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.pp_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,

@@ -251,7 +258,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
+            if use_async:
+                total_size, returned_state_dict, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_pp_format=True,
+                    n_write_entries=191,
+                )
             total_size = save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
                 checkpoint=checkpoint,

@@ -626,7 +642,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model state dict to a single file with given checkpointing path.

@@ -635,6 +653,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
             gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+            use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
         """
         if self.coordinator.is_master():
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")

@@ -651,6 +670,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
             if self.tp_rank == 0:
-                save_state_dict(state_dict, checkpoint, use_safetensors)
+                if use_async:
+                    super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                else:
+                    save_state_dict(state_dict, checkpoint, use_safetensors)
         else:
             # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.

@@ -662,6 +684,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 complete_state_dict = dict()
                 for _state_dict in state_dict_list:
                     complete_state_dict.update(_state_dict)
-                save_state_dict(complete_state_dict, checkpoint, use_safetensors)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    from colossalai.utils.safetensors import move_and_save
+
+                    writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+                    if id(model) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                    self.async_writers.append(writer)
+                    move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+                else:
+                    save_state_dict(complete_state_dict, checkpoint, use_safetensors)

     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
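
Note: the pipeline-parallel sharded path goes through the new async_save_state_dict_shards helper rather than a single pinned buffer. A hedged call sketch with the argument list taken from the hunk above (state_dict_shard, index_file, weights_name and control_saving are built earlier in the same method; what to do with the returned writers is an assumption, modeled on the async_writers bookkeeping used elsewhere in this diff):

    # Illustrative only; not a self-contained program.
    total_size, returned_state_dict, writers = async_save_state_dict_shards(
        sharded_state_dict=state_dict_shard,  # shard generator built by the surrounding method
        checkpoint=checkpoint,                # target directory for the shard files
        index_file=index_file,                # CheckpointIndexFile mapping weights to shard files
        base_filename=weights_name,           # already rewritten with the "-stage-XXXXX-shard" suffix
        is_master=control_saving,             # only the controlling rank writes the index
        use_pp_format=True,                   # keep the per-pipeline-stage file naming
        n_write_entries=191,                  # writer queue depth used throughout this diff
    )
    # Assumption: the returned writers need to be kept (e.g. appended to self.async_writers)
    # and flushed later via _sync_d2h()/_sync_io().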

colossalai/checkpoint_io/moe_checkpoint.py (14 changed lines)

@@ -117,6 +117,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """
         Save sharded model checkpoint under the given checkpointing path.

@@ -161,6 +162,9 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,

@@ -708,9 +712,19 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         checkpoint: str,
         gather_dtensor: bool,
         use_safetensors: bool,
+        use_async: bool = False,
     ):
         state_dict = self.pre_save_model(model)
         if dist.get_rank() == 0:
-            torch.save(state_dict, checkpoint)
+            if use_async:
+                super().save_unsharded_model(
+                    model=model,
+                    checkpoint=checkpoint,
+                    gather_dtensor=gather_dtensor,
+                    use_safetensors=use_safetensors,
+                    use_async=use_async,
+                )
+            else:
+                torch.save(state_dict, checkpoint)
         dist.barrier()
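
Note: the MoE unsharded path keeps the usual collective discipline: every rank runs pre_save_model, only rank 0 writes, and all ranks meet at a barrier so nobody reads a half-written file. A minimal sketch of that pattern with plain torch.distributed (assumes the process group is already initialized):

    import torch
    import torch.distributed as dist

    def save_on_rank0(state_dict, checkpoint_path):
        # only one rank touches the filesystem; the others just wait
        if dist.get_rank() == 0:
            torch.save(state_dict, checkpoint_path)
        # barrier so no rank proceeds (e.g. to load the file) before the write has finished
        dist.barrier()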

colossalai/checkpoint_io/utils.py (13 changed lines)

@@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 # ======================================
-def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
+def save_state_dict(
+    state_dict: dict,
+    checkpoint_file_path: str,
+    use_safetensors: bool,
+) -> None:
     """
     Save state dict to checkpoint.

@@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
         raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
     if use_safetensors:
         from safetensors.torch import load_file as safe_load_file
-        from safetensors.torch import safe_open
-
-        with safe_open(checkpoint_file, framework="pt") as f:
-            metadata = f.metadata()
-        if metadata["format"] != "pt":
-            raise NotImplementedError(
-                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-            )
         return safe_load_file(checkpoint_file)
     else:
         return torch.load(checkpoint_file, map_location=torch.device("cpu"))
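
Note: load_shard_state_dict now trusts safetensors' load_file directly instead of first opening the archive and rejecting files whose metadata lacks format == "pt"; presumably the files produced by the new async writer do not carry that metadata key. Loading therefore reduces to a single call (filename is hypothetical):

    from safetensors.torch import load_file as safe_load_file

    state_dict = safe_load_file("model-00001.safetensors")  # hypothetical shard name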

colossalai/utils/safetensors.py (4 changed lines)

@@ -28,14 +28,12 @@ class PreparedData:
 def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
-    sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
-
     tensors = []
     tensor_keys = []
     metadata = {}
     offset = 0

-    for name, tensor in sorted_data:
+    for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
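
Note: prepare() packs tensors back to back in the file body: each tensor occupies numel * element_size bytes, its (begin, end) byte offsets go into the header, and whatever order the loop iterates in determines the on-disk layout. A self-contained sketch of that offset bookkeeping (plain PyTorch, independent of the TensorInfo/PreparedData types):

    import torch

    def layout_offsets(data):
        # mirrors the bookkeeping in prepare(): cumulative byte offsets per tensor
        offsets, offset = {}, 0
        for name, tensor in data.items():  # iteration order defines the file layout
            n = tensor.numel() * tensor.element_size()
            offsets[name] = (offset, offset + n)
            offset += n
        return offsets

    print(layout_offsets({"w": torch.zeros(2, 3), "b": torch.zeros(3, dtype=torch.float16)}))
    # {'w': (0, 24), 'b': (24, 30)}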

tests/test_checkpoint_io/test_gemini_checkpoint_io.py (16 changed lines)

@@ -63,10 +63,15 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
     model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
     booster.save_model(
-        bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
+        bert_model,
+        pretrained_path,
+        True,
+        True,
+        "",
+        (model_size / 3),
+        use_safetensors=use_safetensors,
     )
     dist.barrier()

     new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
     check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())

@@ -119,7 +124,12 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_model(
+            model,
+            model_ckpt_path,
+            shard=shard,
+            size_per_shard=size_per_shard,
+        )
         booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         dist.barrier()

tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py (23 changed lines)

@@ -26,9 +26,10 @@ from tests.kit.model_zoo import model_zoo
 # only test 2 is fine
 @clear_cache_before_run()
 @parameterize("stage", [2])
-@parameterize("shard", [True, False])
+@parameterize("shard", [False, True])
 @parameterize("offload", [False, True])
-def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
+@parameterize("use_async", [False, True])
+def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
     plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
     booster = Booster(plugin=plugin)
     model = resnet18()

@@ -41,13 +42,26 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
     loss = criterion(output)
     booster.backward(loss, optimizer)
     optimizer.step()

     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
+        if not shard and not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        booster.save_model(
+            model,
+            model_ckpt_path,
+            shard=shard,
+            use_async=use_async,
+        )
         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_model(model, model_ckpt_path, shard=shard)
         booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()

         new_model = resnet18()

@@ -71,6 +85,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
         booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
         torch.cuda.empty_cache()
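
Note: the test encodes a small naming convention: sharded saves write into a directory, an unsharded synchronous save produces a .pt file, and an unsharded async save goes through the safetensors writer and therefore needs a .safetensors suffix. A tiny helper capturing that convention (illustrative; the test itself just inlines the two if statements shown above):

    def checkpoint_path(base: str, shard: bool, use_async: bool) -> str:
        # sharded -> directory; unsharded sync -> .pt; unsharded async -> .safetensors
        if shard:
            return base
        return f"{base}.safetensors" if use_async else f"{base}.pt"

    assert checkpoint_path("/tmp/model", shard=False, use_async=True) == "/tmp/model.safetensors"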
