Browse Source

[checkpointio] disable buffering

pull/6149/head
ver217 15 hours ago
parent
commit
8fddbab04c
  1. 8
      colossalai/booster/plugin/low_level_zero_plugin.py
  2. 2
      colossalai/checkpoint_io/general_checkpoint_io.py
  3. 4
      colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
  4. 2
      colossalai/checkpoint_io/utils.py

8
colossalai/booster/plugin/low_level_zero_plugin.py

@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
fp=open(checkpoint_file_path, "wb", buffering=0),
n_entries=self.N_WRITE_ENTRIES,
backend="pthread",
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)

2
colossalai/checkpoint_io/general_checkpoint_io.py

@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)

4
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
from colossalai.utils.safetensors import move_and_save
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
)
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)

2
colossalai/checkpoint_io/utils.py

@ -311,7 +311,7 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
writers.append(writer)
if pinned_state_dict is not None:

Loading…
Cancel
Save