ckpt_api
wangbluo 2024-11-25 17:12:29 +08:00
parent 82c88c1e0d
commit d6af7be06e
3 changed files with 46 additions and 12 deletions

colossalai/checkpoint_io/general_checkpoint_io.py

@@ -8,13 +8,11 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.utils.safetensors import move_and_save
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    async_save_state_dict,
     async_save_state_dict_shards,
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
@@ -59,13 +57,16 @@ class GeneralCheckpointIO(CheckpointIO):
             pass
 
         if use_async:
-            from tensornvme.async_file_io import AsyncFileWriter
-
-            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            self.async_writers.append(writer)
-            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            new_pinned_state_dict, writers = async_save_state_dict(
+                state_dict,
+                checkpoint,
+                pinned_state_dict,
+                self.N_WRITE_ENTRIES,
+                shard_preprocess=False,
+            )
+            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
         else:
             # save the checkpoint
             save_state_dict(state_dict, checkpoint, use_safetensors)
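For context, a minimal sketch of driving this refactored path from user code. It assumes tensornvme is installed, a CUDA device is available, this branch's save_unsharded_model signature, and the tensornvme AsyncFileWriter API (synchronize() to block on pending writes, fp for the underlying file object):

# Hedged usage sketch: run the async save end to end, then wait for it.
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(1024, 1024).cuda()
ckpt_io = GeneralCheckpointIO()

# Returns quickly: tensors are staged into pinned host buffers and
# flushed to disk by background pthread writers.
ckpt_io.save_unsharded_model(
    model, "model.safetensors", gather_dtensor=False, use_safetensors=True, use_async=True
)

# Block until every queued writer has finished before using the file.
for writer in ckpt_io.async_writers:
    writer.synchronize()
    writer.fp.close()
ckpt_io.async_writers.clear()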

colossalai/checkpoint_io/utils.py

@@ -19,7 +19,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import move_and_save
+from colossalai.utils.safetensors import move_and_save, save
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -266,6 +266,38 @@ def save_state_dict_shards(
     return total_size
 
 
+def async_save_state_dict(
+    state_dict: dict,
+    checkpoint_file_path: str,
+    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
+    n_write_entries: int,
+    shard_preprocess: bool = False,
+    move: bool = True,
+):
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    async_writers = []
+    saved_state_dict, metadata = state_dict, None
+    if pinned_state_dict is None:
+        pinned_state_dict = create_pinned_state_dict(saved_state_dict)
+    f_writer = AsyncFileWriter(fp=open(checkpoint_file_path, "wb"), n_entries=n_write_entries, backend="pthread")
+    if move:
+        move_and_save(
+            f_writer,
+            state_dict=saved_state_dict,
+            metadata=metadata,
+            state_dict_pinned=pinned_state_dict,
+        )
+    else:
+        for name, tensor in saved_state_dict.items():
+            pinned_state_dict[name].copy_(tensor)
+        save(f_writer=f_writer, state_dict=pinned_state_dict, metadata=metadata)
+    async_writers.append(f_writer)
+    return pinned_state_dict, async_writers
+
+
 def async_save_state_dict_shards(
     sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
     checkpoint: str,
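A minimal sketch of calling the new helper directly, assuming tensornvme and a CUDA device. The pinned dict returned by the first call is passed back on later calls, so the page-locked host buffers are allocated once and reused:

import torch

from colossalai.checkpoint_io.utils import async_save_state_dict

state_dict = {"weight": torch.randn(4096, 4096, device="cuda")}

# First call: pinned_state_dict=None, so pinned host buffers are created.
pinned, writers = async_save_state_dict(state_dict, "step0.safetensors", None, n_write_entries=32)
for w in writers:
    w.synchronize()  # wait for the background write to finish
    w.fp.close()

# Later calls reuse the same pinned buffers, skipping re-allocation of
# page-locked memory on every checkpoint.
pinned, writers = async_save_state_dict(state_dict, "step1.safetensors", pinned, n_write_entries=32)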

colossalai/utils/safetensors.py

@@ -170,9 +170,10 @@ def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor])
 
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
+    metadata: Optional[Dict[str, str]] = None,
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
-    prepared_data, _, tensor_keys = prepare(state_dict)
+    prepared_data, _, tensor_keys = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
     f_writer.write(n.to_bytes(8, byteorder="little"))
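The 8-byte little-endian prefix written above is the standard safetensors framing: a u64 header length, then a JSON header mapping tensor names to dtype/shape/data offsets (plus an optional "__metadata__" map, which is where the new metadata argument ends up), then the raw tensor bytes. A sketch of reading that layout back with plain Python, assuming a file produced by the save above:

import json

with open("model.safetensors", "rb") as f:
    header_len = int.from_bytes(f.read(8), byteorder="little")  # the n written above
    header = json.loads(f.read(header_len))

print(header.get("__metadata__"))  # the metadata dict passed to move_and_save
for name, info in header.items():
    if name != "__metadata__":
        print(name, info["dtype"], info["shape"], info["data_offsets"])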