Merge pull request #6149 from ver217/hotfix/ckpt

[checkpointio] disable buffering
pull/6151/head
Wang Binluo 2024-11-21 16:05:19 +08:00 committed by GitHub
commit 8ecff0cb7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 11 additions and 5 deletions

View File

@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread") f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, state_dict) save_nested(f_writer, state_dict)
self.async_writers.append(f_writer) self.async_writers.append(f_writer)
else: else:
@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter( f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread" fp=open(checkpoint_file_path, "wb", buffering=0),
n_entries=self.N_WRITE_ENTRIES,
backend="pthread",
) )
save_nested(f_writer, shard) save_nested(f_writer, shard)
self.async_writers.append(f_writer) self.async_writers.append(f_writer)

View File

@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
if use_async: if use_async:
from tensornvme.async_file_io import AsyncFileWriter from tensornvme.async_file_io import AsyncFileWriter
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts: if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer) self.async_writers.append(writer)

View File

@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
from colossalai.utils.safetensors import move_and_save from colossalai.utils.safetensors import move_and_save
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") writer = AsyncFileWriter(
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
)
if id(model) not in self.pinned_state_dicts: if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer) self.async_writers.append(writer)

View File

@ -311,7 +311,7 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file) index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file)
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread") writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
writers.append(writer) writers.append(writer)
if pinned_state_dict is not None: if pinned_state_dict is not None: