[checkpointio] support async model save (#6131)

* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6124/merge
Hongxin Liu 1 week ago
parent 5a03d2696d
commit d4a436051d

@ -310,6 +310,7 @@ class Booster:
prefix: Optional[str] = None, prefix: Optional[str] = None,
size_per_shard: int = 1024, size_per_shard: int = 1024,
use_safetensors: bool = False, use_safetensors: bool = False,
use_async: bool = False,
) -> None: ) -> None:
"""Save model to checkpoint. """Save model to checkpoint.
@ -333,6 +334,7 @@ class Booster:
prefix=prefix, prefix=prefix,
size_per_shard=size_per_shard, size_per_shard=size_per_shard,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
use_async=use_async,
) )
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:

@ -259,10 +259,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params() model.update_master_params()
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather() model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
def save_sharded_model( def save_sharded_model(
self, self,
@ -272,11 +274,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
prefix: Optional[str] = None, prefix: Optional[str] = None,
max_shard_size: int = 1024, max_shard_size: int = 1024,
use_safetensors: bool = False, use_safetensors: bool = False,
use_async: bool = False,
): ):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather() model._force_wait_all_gather()
return super().save_sharded_model( return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
) )
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):

@ -33,13 +33,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before loading!" assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
""" """
Save model to checkpoint but only on master process. Save model to checkpoint but only on master process.
""" """
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master(): if self.coordinator.is_master():
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) super().save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
""" """
@ -71,6 +75,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None, prefix: Optional[str] = None,
max_shard_size: int = 1024, max_shard_size: int = 1024,
use_safetensors: bool = False, use_safetensors: bool = False,
use_async: bool = False,
): ):
""" """
Save model to checkpoint but only on master process. Save model to checkpoint but only on master process.
@ -78,7 +83,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master(): if self.coordinator.is_master():
super().save_sharded_model( super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors model.unwrap(),
checkpoint_path,
gather_dtensor,
prefix,
max_shard_size,
use_safetensors,
use_async=use_async,
) )
def load_sharded_model( def load_sharded_model(

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Dict, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -8,6 +8,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
@ -58,9 +59,34 @@ class CheckpointIO(ABC):
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt') >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
""" """
N_WRITE_ENTRIES: int = 32
# ====================================== # ======================================
# Public methods # Public methods
# ====================================== # ======================================
def __init__(self):
super().__init__()
self.pinned_state_dicts: Dict[int, dict] = {}
self.async_writers = []
def _sync_io(self):
for writer in self.async_writers:
writer.synchronize()
writer.fp.close()
self.async_writers.clear()
def _sync_d2h(self):
for writer in self.async_writers:
writer.sync_before_step()
def synchronize(self):
"""This method must be called before updating the model weights."""
self._sync_d2h()
def __del__(self):
self._sync_d2h()
self._sync_io()
def load_model( def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
) -> Union[nn.Module, ModelWrapper]: ) -> Union[nn.Module, ModelWrapper]:
@ -111,6 +137,7 @@ class CheckpointIO(ABC):
prefix: str = None, prefix: str = None,
size_per_shard: int = 1024, size_per_shard: int = 1024,
use_safetensors: bool = False, use_safetensors: bool = False,
use_async: bool = False,
): ):
""" """
Save model to checkpoint. Save model to checkpoint.
@ -138,11 +165,21 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
""" """
self._sync_io()
if use_async and not use_safetensors:
logger = get_dist_logger()
logger.warning(
"Async save is only supported when use_safetensors is set to True. "
"Setting use_safetensors to True for async save."
)
use_safetensors = True
if shard: if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) self.save_sharded_model(
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
)
else: else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
""" """
@ -234,6 +271,7 @@ class CheckpointIO(ABC):
prefix: Optional[str], prefix: Optional[str],
size_per_shard: int, size_per_shard: int,
use_safetensors: bool, use_safetensors: bool,
use_async: bool = False,
): ):
""" """
Save model to sharded checkpoint. Save model to sharded checkpoint.
@ -248,7 +286,9 @@ class CheckpointIO(ABC):
""" """
@abstractmethod @abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
""" """
Save model to unsharded checkpoint. Save model to unsharded checkpoint.

@ -8,9 +8,13 @@ from typing import Optional
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.utils.safetensors import move_and_save
from .checkpoint_io_base import CheckpointIO from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile from .index_file import CheckpointIndexFile
from .utils import ( from .utils import (
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames, get_model_base_filenames,
get_optimizer_base_filenames, get_optimizer_base_filenames,
is_safetensors_available, is_safetensors_available,
@ -40,13 +44,25 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint = load_state_dict(checkpoint) checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict) model.load_state_dict(checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
state_dict = model.state_dict() state_dict = model.state_dict()
# TODO(FrankLeeeee): add support for gather_dtensor # TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor: if gather_dtensor:
pass pass
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
# save the checkpoint # save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors) save_state_dict(state_dict, checkpoint, use_safetensors)
@ -151,6 +167,7 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: Optional[str] = None, prefix: Optional[str] = None,
max_shard_size: int = 1024, max_shard_size: int = 1024,
use_safetensors: bool = False, use_safetensors: bool = False,
use_async: bool = False,
): ):
""" """
implement this method as it can be supported by Huggingface model, implement this method as it can be supported by Huggingface model,
@ -168,6 +185,20 @@ class GeneralCheckpointIO(CheckpointIO):
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path) index_file = CheckpointIndexFile(checkpoint_path)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states. # Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior. # In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards( total_size = save_state_dict_shards(

@ -5,7 +5,7 @@ from collections import abc as container_abcs
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
to_global, to_global,
to_global_for_customized_distributed_tensor, to_global_for_customized_distributed_tensor,
) )
from colossalai.utils.safetensors import move_and_save
SAFE_WEIGHTS_NAME = "model.safetensors" SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
@ -263,6 +264,71 @@ def save_state_dict_shards(
return total_size return total_size
def async_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
n_write_entries: int,
use_pp_format: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
"""
from tensornvme.async_file_io import AsyncFileWriter
total_size = 0
shard_filenames = []
if pinned_state_dict is None:
returned_state_dict = {}
else:
returned_state_dict = pinned_state_dict
writers = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Just loop over the sharder and gather to other ranks if not master
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
writers.append(writer)
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(shard)
returned_state_dict.update(sub_pinned_state_dict)
# Only save on master rank.
move_and_save(writer, shard, sub_pinned_state_dict)
shard_filenames.append(shard_file)
del shard
# Clean folder, deleted unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size, returned_state_dict, writers
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
""" """
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file return shard_file
def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
pin_mem = dict()
for name, tensor in state_dict.items():
pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
return pin_mem

@ -1,7 +1,7 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json import json
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from safetensors.torch import _TYPES from safetensors.torch import _TYPES
@ -27,10 +27,11 @@ class PreparedData:
offset: int offset: int
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]: def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0])) sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = [] tensors = []
tensor_keys = []
metadata = {} metadata = {}
offset = 0 offset = 0
@ -42,6 +43,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
offset += n offset += n
metadata[name] = asdict(tensor_info) metadata[name] = asdict(tensor_info)
tensors.append(tensor) tensors.append(tensor)
tensor_keys.append(name)
metadata_buf = json.dumps(metadata).encode("utf-8") metadata_buf = json.dumps(metadata).encode("utf-8")
@ -50,11 +52,11 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
n = len(metadata_buf) n = len(metadata_buf)
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None: def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
prepared_data, tensors = prepare(state_dict) prepared_data, tensors, _ = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little")) f_writer.write(n.to_bytes(8, byteorder="little"))
@ -62,3 +64,22 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
for tensor in tensors: for tensor in tensors:
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset) f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
def move_and_save(
f_writer: AsyncFileWriter,
state_dict: Dict[str, torch.Tensor],
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little"))
f_writer.write(header_bytes)
f_writer.register_h2d(len(tensor_keys))
for name in tensor_keys:
if state_dict_pinned:
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
else:
f_writer.write_tensor(state_dict[name])

Loading…
Cancel
Save