mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] support debug log (#6153)
* [checkpointio] support debug log * [checkpointio] refactor async writer api * fix test * fix testpull/6155/head^2
parent
ab856fd308
commit
6280cb18b8
|
@ -137,12 +137,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
from colossalai.utils.safetensors import save_nested
|
||||
|
||||
f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
|
||||
save_nested(f_writer, state_dict)
|
||||
f_writer = save_nested(checkpoint, state_dict)
|
||||
self.async_writers.append(f_writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
@ -222,16 +220,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
if self.coordinator.is_master():
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
from colossalai.utils.safetensors import save_nested
|
||||
|
||||
f_writer = AsyncFileWriter(
|
||||
checkpoint_file_path,
|
||||
n_entries=self.N_WRITE_ENTRIES,
|
||||
backend="pthread",
|
||||
)
|
||||
save_nested(f_writer, shard)
|
||||
f_writer = save_nested(checkpoint_file_path, shard)
|
||||
self.async_writers.append(f_writer)
|
||||
else:
|
||||
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
||||
|
|
|
@ -59,8 +59,6 @@ class CheckpointIO(ABC):
|
|||
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
|
||||
"""
|
||||
|
||||
N_WRITE_ENTRIES: int = 32
|
||||
|
||||
# ======================================
|
||||
# Public methods
|
||||
# ======================================
|
||||
|
|
|
@ -54,13 +54,11 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
pass
|
||||
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||
self.async_writers.append(writer)
|
||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
||||
|
||||
else:
|
||||
# save the checkpoint
|
||||
|
@ -196,7 +194,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
base_filename=weights_name,
|
||||
is_master=True,
|
||||
pinned_state_dict=pinned_state_dict,
|
||||
n_write_entries=self.N_WRITE_ENTRIES,
|
||||
)
|
||||
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
|
||||
self.async_writers.extend(writers)
|
||||
|
|
|
@ -686,15 +686,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
for _state_dict in state_dict_list:
|
||||
complete_state_dict.update(_state_dict)
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
|
||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||
self.async_writers.append(writer)
|
||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
||||
else:
|
||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||
|
||||
|
|
|
@ -273,7 +273,6 @@ def async_save_state_dict_shards(
|
|||
base_filename: str,
|
||||
is_master: bool,
|
||||
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
|
||||
n_write_entries: int,
|
||||
use_pp_format: bool = False,
|
||||
) -> Tuple[int, Dict[str, torch.Tensor], list]:
|
||||
"""
|
||||
|
@ -290,7 +289,6 @@ def async_save_state_dict_shards(
|
|||
Returns:
|
||||
int: the total size of shards
|
||||
"""
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
total_size = 0
|
||||
shard_filenames = []
|
||||
|
@ -311,9 +309,6 @@ def async_save_state_dict_shards(
|
|||
index_file.append_weight_map(key, shard_file)
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
|
||||
writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
|
||||
writers.append(writer)
|
||||
|
||||
if pinned_state_dict is not None:
|
||||
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
|
||||
else:
|
||||
|
@ -321,7 +316,8 @@ def async_save_state_dict_shards(
|
|||
returned_state_dict.update(sub_pinned_state_dict)
|
||||
|
||||
# Only save on master rank.
|
||||
move_and_save(writer, shard, sub_pinned_state_dict)
|
||||
writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
|
||||
writers.append(writer)
|
||||
shard_filenames.append(shard_file)
|
||||
del shard
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@ import io
|
|||
|
||||
from torch.distributed.distributed_c10d import _pickler, _unpickler
|
||||
|
||||
ASYNC_WRITE_ENTRIES = 32
|
||||
|
||||
|
||||
def _object_to_tensor(obj, device):
|
||||
f = io.BytesIO()
|
||||
|
@ -149,32 +151,31 @@ def prepare(
|
|||
return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
|
||||
|
||||
|
||||
def save(
|
||||
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
|
||||
) -> None:
|
||||
def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
|
||||
prepared_data, tensors, _ = prepare(state_dict, metadata)
|
||||
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||
|
||||
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
|
||||
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||
f_writer.write(header_bytes)
|
||||
|
||||
for tensor in tensors:
|
||||
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
|
||||
return f_writer
|
||||
|
||||
|
||||
def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
flatten_data, metadata = _flatten_optim_state_dict(state_dict)
|
||||
save(f_writer, flatten_data, metadata)
|
||||
return save(path, flatten_data, metadata)
|
||||
|
||||
|
||||
def move_and_save(
|
||||
f_writer: AsyncFileWriter,
|
||||
path: str,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> None:
|
||||
prepared_data, _, tensor_keys = prepare(state_dict)
|
||||
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||
|
||||
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
|
||||
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||
f_writer.write(header_bytes)
|
||||
|
||||
|
@ -184,6 +185,7 @@ def move_and_save(
|
|||
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
|
||||
else:
|
||||
f_writer.write_tensor(state_dict[name])
|
||||
return f_writer
|
||||
|
||||
|
||||
def load_flat(checkpoint_path):
|
||||
|
|
|
@ -83,7 +83,11 @@ class TensorBucket:
|
|||
unflat_buffers = list(map(list, zip(*unflat_buffers)))
|
||||
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
|
||||
write_back_tensor = self._write_back_pairs[tensor]
|
||||
write_back_tensor.data.copy_(
|
||||
_flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
|
||||
)
|
||||
rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
|
||||
if write_back_tensor.is_contiguous():
|
||||
rec_tensor = rec_tensor.view_as(write_back_tensor)
|
||||
else:
|
||||
rec_tensor = rec_tensor.reshape_as(write_back_tensor)
|
||||
write_back_tensor.data.copy_(rec_tensor)
|
||||
|
||||
self.empty()
|
||||
|
|
|
@ -3,18 +3,12 @@ import tempfile
|
|||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
|
||||
|
||||
try:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
||||
|
||||
|
||||
from colossalai.testing import check_state_dict_equal
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_save_load():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
optimizer_state_dict = {
|
||||
|
@ -111,8 +105,7 @@ def test_save_load():
|
|||
}
|
||||
|
||||
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
||||
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict)
|
||||
f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
|
@ -120,8 +113,7 @@ def test_save_load():
|
|||
check_state_dict_equal(load_state_dict, optimizer_state_dict)
|
||||
|
||||
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
|
||||
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict["state"])
|
||||
f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
|
@ -134,8 +126,7 @@ def test_save_load():
|
|||
"module.weight2": torch.rand((1024, 1024)),
|
||||
}
|
||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
||||
save(f_writer, model_state_dict)
|
||||
f_writer = save(model_saved_path, model_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
|
@ -145,8 +136,7 @@ def test_save_load():
|
|||
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
||||
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
||||
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
||||
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
|
||||
f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
|
|
|
@ -10,7 +10,7 @@ from colossalai.logging import disable_existing_loggers
|
|||
from colossalai.nn.optimizer import DistributedLamb, Lamb
|
||||
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
@ -108,6 +108,7 @@ def set_dist_grad(
|
|||
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
||||
@parameterize("bias_correction", [False, True])
|
||||
@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
|
||||
@clear_cache_before_run()
|
||||
def run_dist_lamb_basic(
|
||||
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
|
||||
) -> None:
|
||||
|
@ -177,6 +178,7 @@ def run_dist_lamb_basic(
|
|||
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
||||
@parameterize("bias_correction", [False, True])
|
||||
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
|
||||
@clear_cache_before_run()
|
||||
def run_dist_lamb_fwd_bwd(
|
||||
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue