mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
82c88c1e0d
commit
d6af7be06e
|
@ -8,13 +8,11 @@ from typing import Optional
|
|||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
|
||||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
from .utils import (
|
||||
async_save_state_dict,
|
||||
async_save_state_dict_shards,
|
||||
create_pinned_state_dict,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
is_safetensors_available,
|
||||
|
@ -59,13 +57,16 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
pass
|
||||
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
self.async_writers.append(writer)
|
||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
||||
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
|
||||
new_pinned_state_dict, writers = async_save_state_dict(
|
||||
state_dict,
|
||||
checkpoint,
|
||||
pinned_state_dict,
|
||||
self.N_WRITE_ENTRIES,
|
||||
shard_preprocess=False,
|
||||
)
|
||||
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
|
||||
self.async_writers.extend(writers)
|
||||
|
||||
else:
|
||||
# save the checkpoint
|
||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.tensor.d_tensor import (
|
|||
to_global,
|
||||
to_global_for_customized_distributed_tensor,
|
||||
)
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
from colossalai.utils.safetensors import move_and_save, save
|
||||
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
@ -266,6 +266,38 @@ def save_state_dict_shards(
|
|||
return total_size
|
||||
|
||||
|
||||
def async_save_state_dict(
|
||||
state_dict: dict,
|
||||
checkpoint_file_path: str,
|
||||
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
|
||||
n_write_entries: int,
|
||||
shard_preprocess: bool = False,
|
||||
move: bool = True,
|
||||
):
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
async_writers = []
|
||||
|
||||
saved_state_dict, metadata = state_dict, None
|
||||
if pinned_state_dict is None:
|
||||
pinned_state_dict = create_pinned_state_dict(saved_state_dict)
|
||||
|
||||
f_writer = AsyncFileWriter(fp=open(checkpoint_file_path, "wb"), n_entries=n_write_entries, backend="pthread")
|
||||
if move:
|
||||
move_and_save(
|
||||
f_writer,
|
||||
state_dict=saved_state_dict,
|
||||
metadata=metadata,
|
||||
state_dict_pinned=pinned_state_dict,
|
||||
)
|
||||
else:
|
||||
for name, tensor in saved_state_dict.items():
|
||||
pinned_state_dict[name].copy_(tensor)
|
||||
save(f_writer=f_writer, state_dict=pinned_state_dict, metadata=metadata)
|
||||
async_writers.append(f_writer)
|
||||
return pinned_state_dict, async_writers
|
||||
|
||||
|
||||
def async_save_state_dict_shards(
|
||||
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
|
||||
checkpoint: str,
|
||||
|
|
|
@ -170,9 +170,10 @@ def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor])
|
|||
def move_and_save(
|
||||
f_writer: AsyncFileWriter,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> None:
|
||||
prepared_data, _, tensor_keys = prepare(state_dict)
|
||||
prepared_data, _, tensor_keys = prepare(state_dict, metadata)
|
||||
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||
|
||||
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||
|
|
Loading…
Reference in New Issue