mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] support async model save (#6131)
* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

pull/6147/head
parent 5a03d2696d
commit d4a436051d
@@ -310,6 +310,7 @@ class Booster:
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """Save model to checkpoint.

@@ -333,6 +334,7 @@ class Booster:
             prefix=prefix,
             size_per_shard=size_per_shard,
             use_safetensors=use_safetensors,
+            use_async=use_async,
         )

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
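After this change, a caller can request a background save directly from the booster. A minimal sketch of the intended call pattern (the distributed setup, plugin choice, and the `booster`/`model` objects are omitted, and `booster.checkpoint_io` is an assumption about how the plugin's CheckpointIO instance is reached):

# Hypothetical call site; everything outside the new keyword arguments is placeholder setup.
booster.save_model(
    model,
    "ckpt/",
    shard=True,
    size_per_shard=1024,
    use_safetensors=True,  # async save requires safetensors; it is forced on otherwise
    use_async=True,        # returns without blocking on disk I/O
)

# ... overlap other work here ...

booster.checkpoint_io.synchronize()  # must run before the next weight update
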
@@ -259,10 +259,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

     def save_sharded_model(
         self,

@@ -272,11 +274,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
         return super().save_sharded_model(
-            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )

     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
@@ -33,13 +33,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
         super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(
+                model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
+            )

     def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
         """

@@ -71,6 +75,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.

@@ -78,7 +83,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
             super().save_sharded_model(
-                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+                model.unwrap(),
+                checkpoint_path,
+                gather_dtensor,
+                prefix,
+                max_shard_size,
+                use_safetensors,
+                use_async=use_async,
             )

     def load_sharded_model(
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import torch
 import torch.nn as nn

@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.interface import ModelWrapper
+from colossalai.logging import get_dist_logger

 from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

@@ -58,9 +59,34 @@ class CheckpointIO(ABC):
     >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """

+    N_WRITE_ENTRIES: int = 32
+
     # ======================================
     # Public methods
     # ======================================
+    def __init__(self):
+        super().__init__()
+        self.pinned_state_dicts: Dict[int, dict] = {}
+        self.async_writers = []
+
+    def _sync_io(self):
+        for writer in self.async_writers:
+            writer.synchronize()
+            writer.fp.close()
+        self.async_writers.clear()
+
+    def _sync_d2h(self):
+        for writer in self.async_writers:
+            writer.sync_before_step()
+
+    def synchronize(self):
+        """This method must be called before updating the model weights."""
+        self._sync_d2h()
+
+    def __del__(self):
+        self._sync_d2h()
+        self._sync_io()
+
     def load_model(
         self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
     ) -> Union[nn.Module, ModelWrapper]:

@@ -111,6 +137,7 @@ class CheckpointIO(ABC):
         prefix: str = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint.

@@ -138,11 +165,21 @@ class CheckpointIO(ABC):
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
+        self._sync_io()
+        if use_async and not use_safetensors:
+            logger = get_dist_logger()
+            logger.warning(
+                "Async save is only supported when use_safetensors is set to True. "
+                "Setting use_safetensors to True for async save."
+            )
+            use_safetensors = True

         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
+            self.save_sharded_model(
+                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
+            )
         else:
-            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """

@@ -234,6 +271,7 @@ class CheckpointIO(ABC):
         prefix: Optional[str],
         size_per_shard: int,
         use_safetensors: bool,
+        use_async: bool = False,
     ):
         """
         Save model to sharded checkpoint.

@@ -248,7 +286,9 @@ class CheckpointIO(ABC):
         """

     @abstractmethod
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to unsharded checkpoint.

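The base-class hooks above define the async lifecycle: `save_model(..., use_async=True)` queues `AsyncFileWriter`s, `synchronize()` (the device-to-host sync) must run before the weights are next updated, and the file writes are flushed by the next `save_model()` call (`_sync_io`) or on `__del__`. A rough training-loop sketch, with placeholder names for everything outside `CheckpointIO`:

# Placeholder training loop; only the CheckpointIO calls come from this diff.
checkpoint_io = GeneralCheckpointIO()  # any concrete subclass works

for step, batch in enumerate(dataloader):
    loss = model(batch)
    loss.backward()

    # Device-to-host copies from any in-flight async save must land before the
    # optimizer mutates the weights.
    checkpoint_io.synchronize()
    optimizer.step()
    optimizer.zero_grad()

    if step % save_every == 0:
        # Queues AsyncFileWriters and returns quickly; the writes are flushed by the
        # next save_model() call (_sync_io) or when the CheckpointIO is deleted.
        checkpoint_io.save_model(model, f"ckpt_step{step}.safetensors", use_async=True)
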
@@ -8,9 +8,13 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer

+from colossalai.utils.safetensors import move_and_save
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,

@@ -40,15 +44,27 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         state_dict = model.state_dict()

         # TODO(FrankLeeeee): add support for gather_dtensor
         if gather_dtensor:
             pass

-        # save the checkpoint
-        save_state_dict(state_dict, checkpoint, use_safetensors)
+        if use_async:
+            from tensornvme.async_file_io import AsyncFileWriter
+
+            writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            self.async_writers.append(writer)
+            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+        else:
+
+            # save the checkpoint
+            save_state_dict(state_dict, checkpoint, use_safetensors)

     def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """

@@ -151,6 +167,7 @@ class GeneralCheckpointIO(CheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         implement this method as it can be supported by Huggingface model,

@@ -168,16 +185,30 @@ class GeneralCheckpointIO(CheckpointIO):
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)

-        # Save shards of optimizer states.
-        # In general cases, is_master is set to True to get the right behavior.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            use_safetensors=use_safetensors,
-        )
+        if use_async:
+            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=True,
+                pinned_state_dict=pinned_state_dict,
+                n_write_entries=self.N_WRITE_ENTRIES,
+            )
+            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
+        else:
+            # Save shards of optimizer states.
+            # In general cases, is_master is set to True to get the right behavior.
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=True,
+                use_safetensors=use_safetensors,
+            )

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
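`GeneralCheckpointIO` keeps one pinned CPU copy of the state dict per model (keyed by `id(model)`), so repeated saves reuse the same staging buffers instead of reallocating them. Reduced to plain PyTorch plus a thread, the pattern it relies on looks roughly like the sketch below, an illustrative analogue rather than the tensornvme-backed implementation:

import threading
from typing import Dict, Optional

import torch


def async_save_analogue(
    state_dict: Dict[str, torch.Tensor],
    path: str,
    pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> threading.Thread:
    # Reusable pinned staging buffers make the GPU->CPU copies non-blocking
    # (assumes the state dict lives on a CUDA device).
    if pinned is None:
        pinned = {k: torch.empty_like(v, device="cpu", pin_memory=True) for k, v in state_dict.items()}
    for k, v in state_dict.items():
        pinned[k].copy_(v, non_blocking=True)

    def _write() -> None:
        torch.cuda.synchronize()  # the copies must land before serialization
        torch.save(pinned, path)

    t = threading.Thread(target=_write)
    t.start()
    # The GPU weights may be updated once the copies have landed; join() before the
    # next save reuses `pinned` or before the file is read back.
    return t
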
@@ -5,7 +5,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.nn as nn

@@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
+from colossalai.utils.safetensors import move_and_save

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"

@@ -263,6 +264,71 @@ def save_state_dict_shards(
     return total_size


+def async_save_state_dict_shards(
+    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+    checkpoint: str,
+    index_file: "CheckpointIndexFile",
+    base_filename: str,
+    is_master: bool,
+    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
+    n_write_entries: int,
+    use_pp_format: bool = False,
+) -> Tuple[int, Dict[str, torch.Tensor], list]:
+    """
+    Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
+    Args:
+        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+        checkpoint (str): The path of checkpoint directory as string.
+        index_file (CheckpointIndexFile): The index file object to be updated.
+        base_filename (str): Decides the prefix of filenames of shards.
+        is_master (bool): Whether current rank is main process.
+        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+    Returns:
+        int: the total size of shards
+    """
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    total_size = 0
+    shard_filenames = []
+    if pinned_state_dict is None:
+        returned_state_dict = {}
+    else:
+        returned_state_dict = pinned_state_dict
+    writers = []
+    for idx, shard_pair in enumerate(sharded_state_dict):
+        shard, current_size = shard_pair
+        # Just loop over the sharder and gather to other ranks if not master
+        if not is_master:
+            del shard
+            continue
+        shard_file = get_shard_filename(base_filename, idx)
+        total_size = total_size + current_size
+        for key in shard.keys():
+            index_file.append_weight_map(key, shard_file)
+        checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+        writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
+        writers.append(writer)
+
+        if pinned_state_dict is not None:
+            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
+        else:
+            sub_pinned_state_dict = create_pinned_state_dict(shard)
+            returned_state_dict.update(sub_pinned_state_dict)
+
+        # Only save on master rank.
+        move_and_save(writer, shard, sub_pinned_state_dict)
+        shard_filenames.append(shard_file)
+        del shard
+
+    # Clean folder, deleted unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
+    return total_size, returned_state_dict, writers
+
+
 def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a

@@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int):
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
     shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
+
+
+def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
+    pin_mem = dict()
+    for name, tensor in state_dict.items():
+        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return pin_mem
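For reference, this is roughly how a caller is expected to wire the three return values of `async_save_state_dict_shards` together; it mirrors the `GeneralCheckpointIO` change above, with placeholder literal arguments and a single-process setup (`is_master=True`, `model` assumed to exist):

# Hypothetical caller sketch for the new helper.
pinned_cache = None
for step in (0, 1):
    total_size, pinned_cache, writers = async_save_state_dict_shards(
        sharded_state_dict=shard_model_checkpoint(model.state_dict(), max_shard_size=1024),
        checkpoint="ckpt_dir",
        index_file=CheckpointIndexFile("ckpt_dir"),
        base_filename="model.safetensors",
        is_master=True,
        pinned_state_dict=pinned_cache,  # None on the first call, so pinned buffers get allocated
        n_write_entries=32,
    )
    # Keep the writers alive until the data is on disk, then close them
    # (this is what CheckpointIO._sync_io does).
    for w in writers:
        w.synchronize()
        w.fp.close()
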
@@ -1,7 +1,7 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
 from dataclasses import asdict, dataclass
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 from safetensors.torch import _TYPES

@@ -27,10 +27,11 @@ class PreparedData:
     offset: int


-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
+def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
     sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))

     tensors = []
+    tensor_keys = []
     metadata = {}
     offset = 0

@@ -42,6 +43,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
         offset += n
         metadata[name] = asdict(tensor_info)
         tensors.append(tensor)
+        tensor_keys.append(name)

     metadata_buf = json.dumps(metadata).encode("utf-8")

@@ -50,11 +52,11 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:

     n = len(metadata_buf)

-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
+    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys


 def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors = prepare(state_dict)
+    prepared_data, tensors, _ = prepare(state_dict)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

     f_writer.write(n.to_bytes(8, byteorder="little"))

@@ -62,3 +64,22 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:

     for tensor in tensors:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
+
+
+def move_and_save(
+    f_writer: AsyncFileWriter,
+    state_dict: Dict[str, torch.Tensor],
+    state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
+) -> None:
+    prepared_data, _, tensor_keys = prepare(state_dict)
+    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+
+    f_writer.write(n.to_bytes(8, byteorder="little"))
+    f_writer.write(header_bytes)
+
+    f_writer.register_h2d(len(tensor_keys))
+    for name in tensor_keys:
+        if state_dict_pinned:
+            f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
+        else:
+            f_writer.write_tensor(state_dict[name])
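`save` and `move_and_save` both emit the standard safetensors layout: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/offsets, then the raw tensor bytes. A small reader sketch for sanity-checking a file written this way (illustrative only; in practice use `safetensors.torch.load_file`):

import json
import struct


def read_safetensors_header(path: str) -> dict:
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian length
        return json.loads(f.read(header_len))           # name -> {dtype, shape, data_offsets}


# Example:
# meta = read_safetensors_header("model.safetensors")
# print(list(meta.keys()))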