From d6af7be06e851d091a1a62613aeb7ab0e4f49ad7 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 25 Nov 2024 17:12:29 +0800
Subject: [PATCH] fix

---
 .../checkpoint_io/general_checkpoint_io.py | 21 ++++++------
 colossalai/checkpoint_io/utils.py          | 34 ++++++++++++++++++-
 colossalai/utils/safetensors.py            |  3 +-
 3 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 99e77f7b9..70ad39b67 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -8,13 +8,11 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.utils.safetensors import move_and_save
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    async_save_state_dict,
     async_save_state_dict_shards,
-    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
@@ -59,13 +57,16 @@ class GeneralCheckpointIO(CheckpointIO):
             pass
 
         if use_async:
-            from tensornvme.async_file_io import AsyncFileWriter
-
-            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            self.async_writers.append(writer)
-            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            new_pinned_state_dict, writers = async_save_state_dict(
+                state_dict,
+                checkpoint,
+                pinned_state_dict,
+                self.N_WRITE_ENTRIES,
+                shard_preprocess=False,
+            )
+            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
         else:
             # save the checkpoint
             save_state_dict(state_dict, checkpoint, use_safetensors)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index eb8bb2dcf..c8a4c6c1d 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -19,7 +19,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import move_and_save
+from colossalai.utils.safetensors import move_and_save, save
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -266,6 +266,38 @@ def save_state_dict_shards(
     return total_size
 
 
+def async_save_state_dict(
+    state_dict: dict,
+    checkpoint_file_path: str,
+    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
+    n_write_entries: int,
+    shard_preprocess: bool = False,
+    move: bool = True,
+):
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    async_writers = []
+
+    saved_state_dict, metadata = state_dict, None
+    if pinned_state_dict is None:
+        pinned_state_dict = create_pinned_state_dict(saved_state_dict)
+
+    f_writer = AsyncFileWriter(fp=open(checkpoint_file_path, "wb"), n_entries=n_write_entries, backend="pthread")
+    if move:
+        move_and_save(
+            f_writer,
+            state_dict=saved_state_dict,
+            metadata=metadata,
+            state_dict_pinned=pinned_state_dict,
+        )
+    else:
+        for name, tensor in saved_state_dict.items():
+            pinned_state_dict[name].copy_(tensor)
+        save(f_writer=f_writer, state_dict=pinned_state_dict, metadata=metadata)
+    async_writers.append(f_writer)
+    return pinned_state_dict, async_writers
+
+
 def async_save_state_dict_shards(
     sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
     checkpoint: str,
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 8b8cb627f..39ef8bbc9 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -170,9 +170,10 @@ def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor])
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
+    metadata: Optional[Dict[str, str]] = None,
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
-    prepared_data, _, tensor_keys = prepare(state_dict)
+    prepared_data, _, tensor_keys = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
     f_writer.write(n.to_bytes(8, byteorder="little"))
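
Note (not part of the patch): a minimal usage sketch of the new async_save_state_dict helper added to colossalai/checkpoint_io/utils.py. It assumes tensornvme is installed, and that the returned writers are drained the way GeneralCheckpointIO drains self.async_writers; the output path and n_write_entries value below are illustrative only.

    import torch.nn as nn

    from colossalai.checkpoint_io.utils import async_save_state_dict

    model = nn.Linear(8, 8)
    state_dict = model.state_dict()

    # First call: no pinned buffers exist yet, so the helper allocates a pinned
    # copy of the state dict and returns it for reuse on later saves.
    pinned, writers = async_save_state_dict(
        state_dict,
        "model.safetensors",   # illustrative output path
        pinned_state_dict=None,
        n_write_entries=32,    # illustrative; GeneralCheckpointIO passes self.N_WRITE_ENTRIES
    )

    # Writes happen in the background; block until they are flushed before the
    # process exits (assumes tensornvme's AsyncFileWriter exposes synchronize()).
    for w in writers:
        w.synchronize()

    # Later saves of the same model can pass `pinned` back in to reuse the
    # pinned buffers, as the patch does by caching them keyed by id(model).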