mirror of https://github.com/hpcaitech/ColossalAI
commit
8ecff0cb7f
|
@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||||
|
|
||||||
from colossalai.utils.safetensors import save_nested
|
from colossalai.utils.safetensors import save_nested
|
||||||
|
|
||||||
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
|
f_writer = AsyncFileWriter(
|
||||||
|
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
|
||||||
|
)
|
||||||
save_nested(f_writer, state_dict)
|
save_nested(f_writer, state_dict)
|
||||||
self.async_writers.append(f_writer)
|
self.async_writers.append(f_writer)
|
||||||
else:
|
else:
|
||||||
|
@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||||
from colossalai.utils.safetensors import save_nested
|
from colossalai.utils.safetensors import save_nested
|
||||||
|
|
||||||
f_writer = AsyncFileWriter(
|
f_writer = AsyncFileWriter(
|
||||||
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
|
fp=open(checkpoint_file_path, "wb", buffering=0),
|
||||||
|
n_entries=self.N_WRITE_ENTRIES,
|
||||||
|
backend="pthread",
|
||||||
)
|
)
|
||||||
save_nested(f_writer, shard)
|
save_nested(f_writer, shard)
|
||||||
self.async_writers.append(f_writer)
|
self.async_writers.append(f_writer)
|
||||||
|
|
|
@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
if use_async:
|
if use_async:
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
from tensornvme.async_file_io import AsyncFileWriter
|
||||||
|
|
||||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
|
||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||||
self.async_writers.append(writer)
|
self.async_writers.append(writer)
|
||||||
|
|
|
@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||||
|
|
||||||
from colossalai.utils.safetensors import move_and_save
|
from colossalai.utils.safetensors import move_and_save
|
||||||
|
|
||||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
writer = AsyncFileWriter(
|
||||||
|
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
|
||||||
|
)
|
||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||||
self.async_writers.append(writer)
|
self.async_writers.append(writer)
|
||||||
|
|
|
@ -311,7 +311,7 @@ def async_save_state_dict_shards(
|
||||||
index_file.append_weight_map(key, shard_file)
|
index_file.append_weight_map(key, shard_file)
|
||||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||||
|
|
||||||
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
|
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
|
||||||
writers.append(writer)
|
writers.append(writer)
|
||||||
|
|
||||||
if pinned_state_dict is not None:
|
if pinned_state_dict is not None:
|
||||||
|
|
Loading…
Reference in New Issue