From 8fddbab04c410eab25e4279e5d1d6ce78e6d7776 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 21 Nov 2024 14:33:26 +0800 Subject: [PATCH] [checkpointio] disable buffering --- colossalai/booster/plugin/low_level_zero_plugin.py | 8 ++++++-- colossalai/checkpoint_io/general_checkpoint_io.py | 2 +- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 +++- colossalai/checkpoint_io/utils.py | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 761947344..16bb2e9b8 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): from colossalai.utils.safetensors import save_nested - f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread") + f_writer = AsyncFileWriter( + fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread" + ) save_nested(f_writer, state_dict) self.async_writers.append(f_writer) else: @@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): from colossalai.utils.safetensors import save_nested f_writer = AsyncFileWriter( - fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread" + fp=open(checkpoint_file_path, "wb", buffering=0), + n_entries=self.N_WRITE_ENTRIES, + backend="pthread", ) save_nested(f_writer, shard) self.async_writers.append(f_writer) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a2d1dd158..ddfe5502f 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO): if use_async: from tensornvme.async_file_io import AsyncFileWriter - writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") + writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread") if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) self.async_writers.append(writer) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index d66171c58..581575058 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): from colossalai.utils.safetensors import move_and_save - writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") + writer = AsyncFileWriter( + open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread" + ) if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) self.async_writers.append(writer) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 09cce3059..eb8bb2dcf 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -311,7 +311,7 @@ def async_save_state_dict_shards( index_file.append_weight_map(key, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file) - writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread") + writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread") writers.append(writer) if pinned_state_dict is not None: