mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] support debug log (#6153)
* [checkpointio] support debug log
* [checkpointio] refactor async writer api
* fix test
* fix test
pull/6155/head^2
parent
ab856fd308
commit
6280cb18b8
|
@ -137,12 +137,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||||
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
|
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
if use_async:
|
if use_async:
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
|
|
||||||
from colossalai.utils.safetensors import save_nested
|
from colossalai.utils.safetensors import save_nested
|
||||||
|
|
||||||
f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
|
f_writer = save_nested(checkpoint, state_dict)
|
||||||
save_nested(f_writer, state_dict)
|
|
||||||
self.async_writers.append(f_writer)
|
self.async_writers.append(f_writer)
|
||||||
else:
|
else:
|
||||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||||
|
@ -222,16 +220,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
if use_async:
|
if use_async:
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
|
|
||||||
from colossalai.utils.safetensors import save_nested
|
from colossalai.utils.safetensors import save_nested
|
||||||
|
|
||||||
f_writer = AsyncFileWriter(
|
f_writer = save_nested(checkpoint_file_path, shard)
|
||||||
checkpoint_file_path,
|
|
||||||
n_entries=self.N_WRITE_ENTRIES,
|
|
||||||
backend="pthread",
|
|
||||||
)
|
|
||||||
save_nested(f_writer, shard)
|
|
||||||
self.async_writers.append(f_writer)
|
self.async_writers.append(f_writer)
|
||||||
else:
|
else:
|
||||||
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
||||||
|
|
|
@ -59,8 +59,6 @@ class CheckpointIO(ABC):
|
||||||
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
|
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
N_WRITE_ENTRIES: int = 32
|
|
||||||
|
|
||||||
# ======================================
|
# ======================================
|
||||||
# Public methods
|
# Public methods
|
||||||
# ======================================
|
# ======================================
|
||||||
|
|
|
@ -54,13 +54,11 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if use_async:
|
if use_async:
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
|
|
||||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
|
||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||||
|
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||||
self.async_writers.append(writer)
|
self.async_writers.append(writer)
|
||||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# save the checkpoint
|
# save the checkpoint
|
||||||
|
@ -196,7 +194,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
base_filename=weights_name,
|
base_filename=weights_name,
|
||||||
is_master=True,
|
is_master=True,
|
||||||
pinned_state_dict=pinned_state_dict,
|
pinned_state_dict=pinned_state_dict,
|
||||||
n_write_entries=self.N_WRITE_ENTRIES,
|
|
||||||
)
|
)
|
||||||
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
|
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
|
||||||
self.async_writers.extend(writers)
|
self.async_writers.extend(writers)
|
||||||
|
|
|
@ -686,15 +686,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||||
for _state_dict in state_dict_list:
|
for _state_dict in state_dict_list:
|
||||||
complete_state_dict.update(_state_dict)
|
complete_state_dict.update(_state_dict)
|
||||||
if use_async:
|
if use_async:
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
|
|
||||||
from colossalai.utils.safetensors import move_and_save
|
from colossalai.utils.safetensors import move_and_save
|
||||||
|
|
||||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
|
||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||||
|
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||||
self.async_writers.append(writer)
|
self.async_writers.append(writer)
|
||||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
|
||||||
else:
|
else:
|
||||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
|
|
|
@ -273,7 +273,6 @@ def async_save_state_dict_shards(
|
||||||
base_filename: str,
|
base_filename: str,
|
||||||
is_master: bool,
|
is_master: bool,
|
||||||
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
|
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
|
||||||
n_write_entries: int,
|
|
||||||
use_pp_format: bool = False,
|
use_pp_format: bool = False,
|
||||||
) -> Tuple[int, Dict[str, torch.Tensor], list]:
|
) -> Tuple[int, Dict[str, torch.Tensor], list]:
|
||||||
"""
|
"""
|
||||||
|
@ -290,7 +289,6 @@ def async_save_state_dict_shards(
|
||||||
Returns:
|
Returns:
|
||||||
int: the total size of shards
|
int: the total size of shards
|
||||||
"""
|
"""
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
|
|
||||||
total_size = 0
|
total_size = 0
|
||||||
shard_filenames = []
|
shard_filenames = []
|
||||||
|
@ -311,9 +309,6 @@ def async_save_state_dict_shards(
|
||||||
index_file.append_weight_map(key, shard_file)
|
index_file.append_weight_map(key, shard_file)
|
||||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||||
|
|
||||||
writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
|
|
||||||
writers.append(writer)
|
|
||||||
|
|
||||||
if pinned_state_dict is not None:
|
if pinned_state_dict is not None:
|
||||||
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
|
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
|
||||||
else:
|
else:
|
||||||
|
@ -321,7 +316,8 @@ def async_save_state_dict_shards(
|
||||||
returned_state_dict.update(sub_pinned_state_dict)
|
returned_state_dict.update(sub_pinned_state_dict)
|
||||||
|
|
||||||
# Only save on master rank.
|
# Only save on master rank.
|
||||||
move_and_save(writer, shard, sub_pinned_state_dict)
|
writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
|
||||||
|
writers.append(writer)
|
||||||
shard_filenames.append(shard_file)
|
shard_filenames.append(shard_file)
|
||||||
del shard
|
del shard
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,8 @@ import io
|
||||||
|
|
||||||
from torch.distributed.distributed_c10d import _pickler, _unpickler
|
from torch.distributed.distributed_c10d import _pickler, _unpickler
|
||||||
|
|
||||||
|
ASYNC_WRITE_ENTRIES = 32
|
||||||
|
|
||||||
|
|
||||||
def _object_to_tensor(obj, device):
|
def _object_to_tensor(obj, device):
|
||||||
f = io.BytesIO()
|
f = io.BytesIO()
|
||||||
|
@ -149,32 +151,31 @@ def prepare(
|
||||||
return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
|
return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
|
||||||
|
|
||||||
|
|
||||||
def save(
|
def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
|
||||||
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
|
|
||||||
) -> None:
|
|
||||||
prepared_data, tensors, _ = prepare(state_dict, metadata)
|
prepared_data, tensors, _ = prepare(state_dict, metadata)
|
||||||
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||||
|
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
|
||||||
f_writer.write(n.to_bytes(8, byteorder="little"))
|
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||||
f_writer.write(header_bytes)
|
f_writer.write(header_bytes)
|
||||||
|
|
||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
|
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
|
||||||
|
return f_writer
|
||||||
|
|
||||||
|
|
||||||
def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
|
def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||||
flatten_data, metadata = _flatten_optim_state_dict(state_dict)
|
flatten_data, metadata = _flatten_optim_state_dict(state_dict)
|
||||||
save(f_writer, flatten_data, metadata)
|
return save(path, flatten_data, metadata)
|
||||||
|
|
||||||
|
|
||||||
def move_and_save(
|
def move_and_save(
|
||||||
f_writer: AsyncFileWriter,
|
path: str,
|
||||||
state_dict: Dict[str, torch.Tensor],
|
state_dict: Dict[str, torch.Tensor],
|
||||||
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
|
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
prepared_data, _, tensor_keys = prepare(state_dict)
|
prepared_data, _, tensor_keys = prepare(state_dict)
|
||||||
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||||
|
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
|
||||||
f_writer.write(n.to_bytes(8, byteorder="little"))
|
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||||
f_writer.write(header_bytes)
|
f_writer.write(header_bytes)
|
||||||
|
|
||||||
|
@ -184,6 +185,7 @@ def move_and_save(
|
||||||
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
|
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
|
||||||
else:
|
else:
|
||||||
f_writer.write_tensor(state_dict[name])
|
f_writer.write_tensor(state_dict[name])
|
||||||
|
return f_writer
|
||||||
|
|
||||||
|
|
||||||
def load_flat(checkpoint_path):
|
def load_flat(checkpoint_path):
|
||||||
|
|
|
@ -83,7 +83,11 @@ class TensorBucket:
|
||||||
unflat_buffers = list(map(list, zip(*unflat_buffers)))
|
unflat_buffers = list(map(list, zip(*unflat_buffers)))
|
||||||
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
|
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
|
||||||
write_back_tensor = self._write_back_pairs[tensor]
|
write_back_tensor = self._write_back_pairs[tensor]
|
||||||
write_back_tensor.data.copy_(
|
rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
|
||||||
_flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
|
if write_back_tensor.is_contiguous():
|
||||||
)
|
rec_tensor = rec_tensor.view_as(write_back_tensor)
|
||||||
|
else:
|
||||||
|
rec_tensor = rec_tensor.reshape_as(write_back_tensor)
|
||||||
|
write_back_tensor.data.copy_(rec_tensor)
|
||||||
|
|
||||||
self.empty()
|
self.empty()
|
||||||
|
|
|
@ -3,18 +3,12 @@ import tempfile
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
|
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
|
||||||
|
|
||||||
try:
|
|
||||||
from tensornvme.async_file_io import AsyncFileWriter
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
|
||||||
|
|
||||||
|
|
||||||
from colossalai.testing import check_state_dict_equal
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
|
|
||||||
|
|
||||||
|
@clear_cache_before_run()
|
||||||
def test_save_load():
|
def test_save_load():
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
optimizer_state_dict = {
|
optimizer_state_dict = {
|
||||||
|
@ -111,8 +105,7 @@ def test_save_load():
|
||||||
}
|
}
|
||||||
|
|
||||||
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
||||||
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
|
f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
|
||||||
save_nested(f_writer, optimizer_state_dict)
|
|
||||||
f_writer.sync_before_step()
|
f_writer.sync_before_step()
|
||||||
f_writer.synchronize()
|
f_writer.synchronize()
|
||||||
del f_writer
|
del f_writer
|
||||||
|
@ -120,8 +113,7 @@ def test_save_load():
|
||||||
check_state_dict_equal(load_state_dict, optimizer_state_dict)
|
check_state_dict_equal(load_state_dict, optimizer_state_dict)
|
||||||
|
|
||||||
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
|
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
|
||||||
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
|
f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
|
||||||
save_nested(f_writer, optimizer_state_dict["state"])
|
|
||||||
f_writer.sync_before_step()
|
f_writer.sync_before_step()
|
||||||
f_writer.synchronize()
|
f_writer.synchronize()
|
||||||
del f_writer
|
del f_writer
|
||||||
|
@ -134,8 +126,7 @@ def test_save_load():
|
||||||
"module.weight2": torch.rand((1024, 1024)),
|
"module.weight2": torch.rand((1024, 1024)),
|
||||||
}
|
}
|
||||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
f_writer = save(model_saved_path, model_state_dict)
|
||||||
save(f_writer, model_state_dict)
|
|
||||||
f_writer.sync_before_step()
|
f_writer.sync_before_step()
|
||||||
f_writer.synchronize()
|
f_writer.synchronize()
|
||||||
del f_writer
|
del f_writer
|
||||||
|
@ -145,8 +136,7 @@ def test_save_load():
|
||||||
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
||||||
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
||||||
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
||||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
|
||||||
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
|
|
||||||
f_writer.sync_before_step()
|
f_writer.sync_before_step()
|
||||||
f_writer.synchronize()
|
f_writer.synchronize()
|
||||||
del f_writer
|
del f_writer
|
||||||
|
|
|
@ -10,7 +10,7 @@ from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.nn.optimizer import DistributedLamb, Lamb
|
from colossalai.nn.optimizer import DistributedLamb, Lamb
|
||||||
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
|
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
@ -108,6 +108,7 @@ def set_dist_grad(
|
||||||
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
||||||
@parameterize("bias_correction", [False, True])
|
@parameterize("bias_correction", [False, True])
|
||||||
@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
|
@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
|
||||||
|
@clear_cache_before_run()
|
||||||
def run_dist_lamb_basic(
|
def run_dist_lamb_basic(
|
||||||
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
|
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -177,6 +178,7 @@ def run_dist_lamb_basic(
|
||||||
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
||||||
@parameterize("bias_correction", [False, True])
|
@parameterize("bias_correction", [False, True])
|
||||||
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
|
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
|
||||||
|
@clear_cache_before_run()
|
||||||
def run_dist_lamb_fwd_bwd(
|
def run_dist_lamb_fwd_bwd(
|
||||||
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
|
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Reference in New Issue