[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
Hongxin Liu 2024-12-02 11:29:19 +08:00 committed by GitHub
parent ab856fd308
commit 6280cb18b8
9 changed files with 33 additions and 54 deletions

@@ -137,12 +137,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
-                from tensornvme.async_file_io import AsyncFileWriter
                 from colossalai.utils.safetensors import save_nested
-                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-                save_nested(f_writer, state_dict)
+                f_writer = save_nested(checkpoint, state_dict)
                 self.async_writers.append(f_writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -222,16 +220,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
                 if use_async:
-                    from tensornvme.async_file_io import AsyncFileWriter
                     from colossalai.utils.safetensors import save_nested
-                    f_writer = AsyncFileWriter(
-                        checkpoint_file_path,
-                        n_entries=self.N_WRITE_ENTRIES,
-                        backend="pthread",
-                    )
-                    save_nested(f_writer, shard)
+                    f_writer = save_nested(checkpoint_file_path, shard)
                     self.async_writers.append(f_writer)
                 else:
                     save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
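
The two call sites above show the refactored contract: save_nested now receives a file path, constructs the AsyncFileWriter internally, and returns it, so the caller only keeps the returned writer for later synchronization. A minimal illustrative sketch of that usage (the toy model/optimizer and file name are assumptions, and tensornvme must be installed):

import torch

from colossalai.utils.safetensors import save_nested

# toy model/optimizer just to obtain a realistic optimizer state dict
model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optim.step()

f_writer = save_nested("optimizer.safetensors", optim.state_dict())  # async write starts here
f_writer.sync_before_step()
f_writer.synchronize()  # both calls mirror the test below: block until the file is fully written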

@@ -59,8 +59,6 @@ class CheckpointIO(ABC):
     >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """
-    N_WRITE_ENTRIES: int = 32
     # ======================================
     # Public methods
     # ======================================

@@ -54,13 +54,11 @@ class GeneralCheckpointIO(CheckpointIO):
             pass
         if use_async:
-            from tensornvme.async_file_io import AsyncFileWriter
-            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
             self.async_writers.append(writer)
-            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
         else:
             # save the checkpoint
@@ -196,7 +194,6 @@ class GeneralCheckpointIO(CheckpointIO):
                 base_filename=weights_name,
                 is_master=True,
                 pinned_state_dict=pinned_state_dict,
-                n_write_entries=self.N_WRITE_ENTRIES,
             )
             self.pinned_state_dicts[id(model)] = new_pinned_state_dict
             self.async_writers.extend(writers)

@@ -686,15 +686,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 for _state_dict in state_dict_list:
                     complete_state_dict.update(_state_dict)
                 if use_async:
-                    from tensornvme.async_file_io import AsyncFileWriter
                     from colossalai.utils.safetensors import move_and_save
-                    writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
                     if id(model) not in self.pinned_state_dicts:
                         self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                    writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
                     self.async_writers.append(writer)
-                    move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
                 else:
                     save_state_dict(complete_state_dict, checkpoint, use_safetensors)

@@ -273,7 +273,6 @@ def async_save_state_dict_shards(
     base_filename: str,
     is_master: bool,
     pinned_state_dict: Optional[Dict[str, torch.Tensor]],
-    n_write_entries: int,
     use_pp_format: bool = False,
 ) -> Tuple[int, Dict[str, torch.Tensor], list]:
     """
@@ -290,7 +289,6 @@
     Returns:
         int: the total size of shards
     """
-    from tensornvme.async_file_io import AsyncFileWriter
     total_size = 0
     shard_filenames = []
@@ -311,9 +309,6 @@
             index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)
-        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
-        writers.append(writer)
         if pinned_state_dict is not None:
             sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
         else:
@@ -321,7 +316,8 @@
             returned_state_dict.update(sub_pinned_state_dict)
         # Only save on master rank.
-        move_and_save(writer, shard, sub_pinned_state_dict)
+        writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
+        writers.append(writer)
         shard_filenames.append(shard_file)
         del shard
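
In the sharded path above, one writer is created per shard and collected for later draining. A standalone sketch of that loop under the new API (the shard contents, file names, and the final drain loop are illustrative assumptions, not the library's own code):

import os
import tempfile

import torch

from colossalai.utils.safetensors import move_and_save

checkpoint_dir = tempfile.mkdtemp()
# two fake shards; in the real code these come from the state-dict sharder
shards = [
    ("model-00001.safetensors", {"layer1.weight": torch.rand(8, 8)}),
    ("model-00002.safetensors", {"layer2.weight": torch.rand(8, 8)}),
]

writers = []
for shard_file, shard in shards:
    checkpoint_file_path = os.path.join(checkpoint_dir, shard_file)
    writer = move_and_save(checkpoint_file_path, shard)  # pinned buffers omitted for brevity
    writers.append(writer)

# once every shard is queued, block until all background writes finish
for writer in writers:
    writer.sync_before_step()
    writer.synchronize()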

@@ -15,6 +15,8 @@ import io
 from torch.distributed.distributed_c10d import _pickler, _unpickler
+ASYNC_WRITE_ENTRIES = 32
 def _object_to_tensor(obj, device):
     f = io.BytesIO()
@@ -149,32 +151,31 @@ def prepare(
     return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
-def save(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
+def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
     prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)
     for tensor in tensors:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
+    return f_writer
-def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
     flatten_data, metadata = _flatten_optim_state_dict(state_dict)
-    save(f_writer, flatten_data, metadata)
+    return save(path, flatten_data, metadata)
 def move_and_save(
-    f_writer: AsyncFileWriter,
+    path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     prepared_data, _, tensor_keys = prepare(state_dict)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)
@@ -184,6 +185,7 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+    return f_writer
 def load_flat(checkpoint_path):
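
The refactored move_and_save mirrors save: it takes the destination path plus an optional dict of pre-pinned host buffers, instantiates the AsyncFileWriter itself, and returns it. A hedged usage sketch modeled on the test below (tensor names and shapes are made up, and it assumes a CUDA device and an installed tensornvme):

import torch

from colossalai.utils.safetensors import move_and_save

state_dict_cpu = {"module.weight": torch.rand(1024, 1024)}
state_dict_cuda = {k: v.cuda() for k, v in state_dict_cpu.items()}          # assumes a CUDA device
state_dict_pinned = {k: v.pin_memory() for k, v in state_dict_cpu.items()}  # host staging buffers

f_writer = move_and_save("model.safetensors", state_dict_cuda, state_dict_pinned)
f_writer.sync_before_step()  # device-to-host copies into the pinned buffers are done
f_writer.synchronize()       # the file write has finished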

@@ -83,7 +83,11 @@ class TensorBucket:
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
             write_back_tensor = self._write_back_pairs[tensor]
-            write_back_tensor.data.copy_(
-                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
-            )
+            rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
+            if write_back_tensor.is_contiguous():
+                rec_tensor = rec_tensor.view_as(write_back_tensor)
+            else:
+                rec_tensor = rec_tensor.reshape_as(write_back_tensor)
+            write_back_tensor.data.copy_(rec_tensor)
         self.empty()
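
To make the write-back step above concrete, here is a small self-contained sketch of the same pattern using the same torch helper (shapes and values are made up): the unflattened shards are re-flattened, truncated to the original tensor's element count to drop any padding, given the original shape, and copied back in place.

import torch
from torch._utils import _flatten_dense_tensors

write_back_tensor = torch.zeros(3, 4)
unflat_shards = [torch.arange(0, 6, dtype=torch.float32), torch.arange(6, 12, dtype=torch.float32)]

rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
# view_as suffices when the destination layout is contiguous; reshape_as covers the general case
if write_back_tensor.is_contiguous():
    rec_tensor = rec_tensor.view_as(write_back_tensor)
else:
    rec_tensor = rec_tensor.reshape_as(write_back_tensor)
write_back_tensor.data.copy_(rec_tensor)
print(write_back_tensor)  # values 0..11 laid out as a 3x4 matrix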

@@ -3,18 +3,12 @@ import tempfile
 import torch
 from safetensors.torch import load_file
+from colossalai.testing import check_state_dict_equal, clear_cache_before_run
+from colossalai.utils import get_current_device
 from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
-try:
-    from tensornvme.async_file_io import AsyncFileWriter
-except ModuleNotFoundError:
-    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
-from colossalai.testing import check_state_dict_equal
-from colossalai.utils import get_current_device
+@clear_cache_before_run()
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
@@ -111,8 +105,7 @@ def test_save_load():
         }
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
-        f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict)
+        f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
@@ -120,8 +113,7 @@
         check_state_dict_equal(load_state_dict, optimizer_state_dict)
         optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
-        f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
@@ -134,8 +126,7 @@
             "module.weight2": torch.rand((1024, 1024)),
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
-        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
-        save(f_writer, model_state_dict)
+        f_writer = save(model_saved_path, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
@@ -145,8 +136,7 @@
         model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
         model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
         model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
-        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
-        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer

@@ -10,7 +10,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import DistributedLamb, Lamb
 from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
 from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
@@ -108,6 +108,7 @@ def set_dist_grad(
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("bias_correction", [False, True])
 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
+@clear_cache_before_run()
 def run_dist_lamb_basic(
     bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
 ) -> None:
@@ -177,6 +178,7 @@ def run_dist_lamb_basic(
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("bias_correction", [False, True])
 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
+@clear_cache_before_run()
 def run_dist_lamb_fwd_bwd(
     bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
 ) -> None: