diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 8047d90f7..2518b2511 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -310,6 +310,7 @@ class Booster: prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ) -> None: """Save model to checkpoint. @@ -333,6 +334,7 @@ class Booster: prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, + use_async=use_async, ) def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index d5afa2ba8..d4eb1bbed 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -259,10 +259,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" model._force_wait_all_gather() - return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) def save_sharded_model( self, @@ -272,11 +274,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" model._force_wait_all_gather() return super().save_sharded_model( - model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async ) def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 156a4acf9..09830a2f9 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -33,13 +33,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): assert isinstance(model, ModelWrapper), "Please boost the model before loading!" super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): """ Save model to checkpoint but only on master process. """ assert isinstance(model, ModelWrapper), "Please boost the model before saving!" if self.coordinator.is_master(): - super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) + super().save_unsharded_model( + model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async + ) def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): """ @@ -71,6 +75,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): """ Save model to checkpoint but only on master process. @@ -78,7 +83,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): assert isinstance(model, ModelWrapper), "Please boost the model before saving!" if self.coordinator.is_master(): super().save_sharded_model( - model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + model.unwrap(), + checkpoint_path, + gather_dtensor, + prefix, + max_shard_size, + use_safetensors, + use_async=use_async, ) def load_sharded_model( diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 949ba4d44..4d752f3e6 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Union import torch import torch.nn as nn @@ -8,6 +8,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.interface import ModelWrapper +from colossalai.logging import get_dist_logger from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file @@ -58,9 +59,34 @@ class CheckpointIO(ABC): >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt') """ + N_WRITE_ENTRIES: int = 32 + # ====================================== # Public methods # ====================================== + def __init__(self): + super().__init__() + self.pinned_state_dicts: Dict[int, dict] = {} + self.async_writers = [] + + def _sync_io(self): + for writer in self.async_writers: + writer.synchronize() + writer.fp.close() + self.async_writers.clear() + + def _sync_d2h(self): + for writer in self.async_writers: + writer.sync_before_step() + + def synchronize(self): + """This method must be called before updating the model weights.""" + self._sync_d2h() + + def __del__(self): + self._sync_d2h() + self._sync_io() + def load_model( self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True ) -> Union[nn.Module, ModelWrapper]: @@ -111,6 +137,7 @@ class CheckpointIO(ABC): prefix: str = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): """ Save model to checkpoint. @@ -138,11 +165,21 @@ class CheckpointIO(ABC): size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ + self._sync_io() + if use_async and not use_safetensors: + logger = get_dist_logger() + logger.warning( + "Async save is only supported when use_safetensors is set to True. " + "Setting use_safetensors to True for async save." + ) + use_safetensors = True if shard: - self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) + self.save_sharded_model( + model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async + ) else: - self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): """ @@ -234,6 +271,7 @@ class CheckpointIO(ABC): prefix: Optional[str], size_per_shard: int, use_safetensors: bool, + use_async: bool = False, ): """ Save model to sharded checkpoint. @@ -248,7 +286,9 @@ class CheckpointIO(ABC): """ @abstractmethod - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): """ Save model to unsharded checkpoint. diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2534fa163..a4866e64c 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,9 +8,13 @@ from typing import Optional import torch.nn as nn from torch.optim import Optimizer +from colossalai.utils.safetensors import move_and_save + from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( + async_save_state_dict_shards, + create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, is_safetensors_available, @@ -40,15 +44,27 @@ class GeneralCheckpointIO(CheckpointIO): checkpoint = load_state_dict(checkpoint) model.load_state_dict(checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): state_dict = model.state_dict() # TODO(FrankLeeeee): add support for gather_dtensor if gather_dtensor: pass - # save the checkpoint - save_state_dict(state_dict, checkpoint, use_safetensors) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + self.async_writers.append(writer) + move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) + else: + + # save the checkpoint + save_state_dict(state_dict, checkpoint, use_safetensors) def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): """ @@ -151,6 +167,7 @@ class GeneralCheckpointIO(CheckpointIO): prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): """ implement this method as it can be supported by Huggingface model, @@ -168,16 +185,30 @@ class GeneralCheckpointIO(CheckpointIO): weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) - # Save shards of optimizer states. - # In general cases, is_master is set to True to get the right behavior. - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=True, - use_safetensors=use_safetensors, - ) + if use_async: + pinned_state_dict = self.pinned_state_dicts.get(id(model), None) + total_size, new_pinned_state_dict, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, + pinned_state_dict=pinned_state_dict, + n_write_entries=self.N_WRITE_ENTRIES, + ) + self.pinned_state_dicts[id(model)] = new_pinned_state_dict + self.async_writers.extend(writers) + else: + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, + use_safetensors=use_safetensors, + ) index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index b3917bd9d..6d539cce6 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -5,7 +5,7 @@ from collections import abc as container_abcs from collections import defaultdict from itertools import chain from pathlib import Path -from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple +from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch import torch.nn as nn @@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import ( to_global, to_global_for_customized_distributed_tensor, ) +from colossalai.utils.safetensors import move_and_save SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -263,6 +264,71 @@ def save_state_dict_shards( return total_size +def async_save_state_dict_shards( + sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + pinned_state_dict: Optional[Dict[str, torch.Tensor]], + n_write_entries: int, + use_pp_format: bool = False, +) -> Tuple[int, Dict[str, torch.Tensor], list]: + """ + Save sharded state dict only on master rank, this method can be used by both model and optimizer states. + Args: + sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. + checkpoint (str): The path of checkpoint directory as string. + index_file (CheckpointIndexFile): The index file object to be updated. + base_filename (str): Decides the prefix of filenames of shards. + is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + Returns: + int: the total size of shards + """ + from tensornvme.async_file_io import AsyncFileWriter + + total_size = 0 + shard_filenames = [] + if pinned_state_dict is None: + returned_state_dict = {} + else: + returned_state_dict = pinned_state_dict + writers = [] + for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair + # Just loop over the sharder and gather to other ranks if not master + if not is_master: + del shard + continue + shard_file = get_shard_filename(base_filename, idx) + total_size = total_size + current_size + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint, shard_file) + + writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread") + writers.append(writer) + + if pinned_state_dict is not None: + sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()} + else: + sub_pinned_state_dict = create_pinned_state_dict(shard) + returned_state_dict.update(sub_pinned_state_dict) + + # Only save on master rank. + move_and_save(writer, shard, sub_pinned_state_dict) + shard_filenames.append(shard_file) + del shard + + # Clean folder, deleted unneeded files. + clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) + + return total_size, returned_state_dict, writers + + def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int): shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") return shard_file + + +def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]): + pin_mem = dict() + for name, tensor in state_dict.items(): + pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu") + return pin_mem diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index 9aa3558d9..bf8decd0f 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -1,7 +1,7 @@ # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 import json from dataclasses import asdict, dataclass -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import torch from safetensors.torch import _TYPES @@ -27,10 +27,11 @@ class PreparedData: offset: int -def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]: +def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]: sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0])) tensors = [] + tensor_keys = [] metadata = {} offset = 0 @@ -42,6 +43,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten offset += n metadata[name] = asdict(tensor_info) tensors.append(tensor) + tensor_keys.append(name) metadata_buf = json.dumps(metadata).encode("utf-8") @@ -50,11 +52,11 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten n = len(metadata_buf) - return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors + return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None: - prepared_data, tensors = prepare(state_dict) + prepared_data, tensors, _ = prepare(state_dict) n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset f_writer.write(n.to_bytes(8, byteorder="little")) @@ -62,3 +64,22 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None for tensor in tensors: f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset) + + +def move_and_save( + f_writer: AsyncFileWriter, + state_dict: Dict[str, torch.Tensor], + state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None, +) -> None: + prepared_data, _, tensor_keys = prepare(state_dict) + n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset + + f_writer.write(n.to_bytes(8, byteorder="little")) + f_writer.write(header_bytes) + + f_writer.register_h2d(len(tensor_keys)) + for name in tensor_keys: + if state_dict_pinned: + f_writer.write_tensor(state_dict[name], state_dict_pinned[name]) + else: + f_writer.write_tensor(state_dict[name])