Browse Source

Merge pull request #6149 from ver217/hotfix/ckpt

[checkpointio] disable buffering
main
Wang Binluo 18 hours ago committed by GitHub
parent
commit
8ecff0cb7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 8
      colossalai/booster/plugin/low_level_zero_plugin.py
  2. 2
      colossalai/checkpoint_io/general_checkpoint_io.py
  3. 4
      colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
  4. 2
      colossalai/checkpoint_io/utils.py

8
colossalai/booster/plugin/low_level_zero_plugin.py

@@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 from colossalai.utils.safetensors import save_nested
-f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+f_writer = AsyncFileWriter(
+fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+)
 save_nested(f_writer, state_dict)
 self.async_writers.append(f_writer)
 else:
@@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 from colossalai.utils.safetensors import save_nested
 f_writer = AsyncFileWriter(
-fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+fp=open(checkpoint_file_path, "wb", buffering=0),
+n_entries=self.N_WRITE_ENTRIES,
+backend="pthread",
 )
 save_nested(f_writer, shard)
 self.async_writers.append(f_writer)

2
colossalai/checkpoint_io/general_checkpoint_io.py

@@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
 if use_async:
 from tensornvme.async_file_io import AsyncFileWriter
-writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
 if id(model) not in self.pinned_state_dicts:
 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
 self.async_writers.append(writer)

4
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

@@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 from colossalai.utils.safetensors import move_and_save
-writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+writer = AsyncFileWriter(
+open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
+)
 if id(model) not in self.pinned_state_dicts:
 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
 self.async_writers.append(writer)

2
colossalai/checkpoint_io/utils.py

@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
 index_file.append_weight_map(key, shard_file)
 checkpoint_file_path = os.path.join(checkpoint, shard_file)
-writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
+writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
 writers.append(writer)
 if pinned_state_dict is not None:

Loading…
Cancel
Save