[async io]supoort async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
pull/6124/merge
flybird11111 7 days ago committed by Hongxin Liu
parent b90835bd32
commit eb69e640e5

@ -359,6 +359,7 @@ class Booster:
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
) -> None:
"""
Save optimizer to checkpoint.
@ -374,7 +375,9 @@ class Booster:
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
self.checkpoint_io.save_optimizer(
optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.

@ -94,7 +94,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@ -178,7 +180,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
def save_sharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: GeminiOptimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer state dict to checkpoint folder.

@ -24,6 +24,7 @@ from colossalai.checkpoint_io.utils import (
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
@ -113,7 +114,9 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
):
"""Save optimizer to checkpoint but only on master process.
Args:
@ -125,9 +128,34 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
def save_sharded_optimizer(
self,
@ -136,6 +164,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@ -161,10 +190,16 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
@ -184,7 +219,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)
else:
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@ -223,7 +269,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
from colossalai.utils.safetensors import load_flat
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():

@ -52,7 +52,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
@ -113,13 +115,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(
optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
def load_sharded_optimizer(
self,

@ -67,7 +67,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
@ -157,7 +159,13 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to checkpoint but only on master process.

@ -213,6 +213,7 @@ class CheckpointIO(ABC):
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@ -229,11 +230,12 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
self.save_sharded_optimizer(
optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
else:
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
# ========================================================
# Abstract methods for model loading/saving implementation
@ -326,7 +328,13 @@ class CheckpointIO(ABC):
@abstractmethod
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint.
@ -340,7 +348,9 @@ class CheckpointIO(ABC):
"""
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to unsharded checkpoint.

@ -98,6 +98,7 @@ class GeneralCheckpointIO(CheckpointIO):
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@ -155,6 +156,7 @@ class GeneralCheckpointIO(CheckpointIO):
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

@ -416,6 +416,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@ -725,7 +726,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer state dict to a file with given path.

@ -369,6 +369,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@ -729,7 +730,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
dist.barrier()
# Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool,
use_async: bool = False,
):
"""
Save optimizer state dict to a file with given path.

@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_STATE_NAME = "optimizer.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"
# ======================================
@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
return weights_name, save_index_file
def get_optimizer_base_filenames(prefix: str = None):
def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
states_name = add_prefix(states_name, prefix)
save_index_file = STATES_INDEX_NAME
save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)
param_group_file = GROUP_FILE_NAME

@ -1,4 +1,4 @@
from typing import Any, List, OrderedDict
from typing import Any, List, OrderedDict, Tuple
import torch
import torch.distributed as dist
@ -78,7 +78,9 @@ def check_state_dict_equal(
v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
assert v1 == v2, f"{v1} not equals to {v2}"
if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
v2 = tuple(v2)
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):

@ -1,10 +1,11 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
import warnings
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple
import torch
from safetensors.torch import _TYPES
from safetensors.torch import _TYPES, load_file, safe_open
try:
from tensornvme.async_file_io import AsyncFileWriter
@ -27,34 +28,93 @@ class PreparedData:
offset: int
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
def flatten_dict(nested_dict, parent_key="", separator="^"):
"""
Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
nested_dict: The input nested dictionary.
parent_key: The parent key currently being processed.
separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
"""
items = []
for k, v in nested_dict.items():
new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, separator).items())
else:
v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
items.append((new_key, v))
return dict(items)
def unflatten_dict(flattened_dict, separator="^"):
"""
Restore a flattened dictionary back to a multi-level nested dictionary.
flattened_dict: The flattened dictionary.
separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
"""
nested_dict = {}
for key, value in flattened_dict.items():
keys = key.split(separator)
try:
keys[0] = int(keys[0])
except ValueError:
warnings.warn(f"{key[0]} can't convert to integer")
d = nested_dict
for part in keys[:-1]:
if part not in d:
d[part] = {}
d = d[part]
assert isinstance(value, torch.Tensor)
d[keys[-1]] = value
return nested_dict
def prepare(
data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
if metadata is not None:
assert isinstance(metadata, dict)
for k, v in metadata.items():
metadata[k] = json.dumps(v)
assert isinstance(k, str)
assert isinstance(metadata[k], str)
tensors = []
tensor_keys = []
metadata = {}
header = {}
offset = 0
if metadata is not None:
header["__metadata__"] = metadata
for name, tensor in data.items():
n = tensor.numel() * tensor.element_size()
tensor_info = TensorInfo(
dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
)
offset += n
metadata[name] = asdict(tensor_info)
header[name] = asdict(tensor_info)
tensors.append(tensor)
tensor_keys.append(name)
metadata_buf = json.dumps(metadata).encode("utf-8")
header_buf = json.dumps(header).encode("utf-8")
extra = (8 - len(metadata_buf) % 8) % 8
metadata_buf += b" " * extra
extra = (8 - len(header_buf) % 8) % 8
header_buf += b" " * extra
n = len(metadata_buf)
n = len(header_buf)
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
prepared_data, tensors, _ = prepare(state_dict)
def save(
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> None:
prepared_data, tensors, _ = prepare(state_dict, metadata)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little"))
@ -64,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
def save_nested(
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> None:
flatten_data = flatten_dict(state_dict)
save(f_writer, flatten_data, metadata)
def move_and_save(
f_writer: AsyncFileWriter,
state_dict: Dict[str, torch.Tensor],
@ -81,3 +148,16 @@ def move_and_save(
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
else:
f_writer.write_tensor(state_dict[name])
def load_flat(checkpoint_path):
with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata()
state_dict_load = load_file(checkpoint_path)
state_dict = unflatten_dict(state_dict_load)
if metadata is None:
return state_dict
metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
combined_state_dict = {"state": state_dict}
combined_state_dict.update(metadata)
return combined_state_dict

@ -770,7 +770,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return {"state": packed_state, "param_groups": param_groups}
def state_dict(self) -> Dict:
def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
"""Return a state_dict same with DDP
Returns:
@ -779,15 +779,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
zero_state = dict()
device = get_accelerator().get_current_device()
for param, state in self.optim.state.items():
if pinned_state_dicts and param not in pinned_state_dicts:
pinned_state_dicts[param] = {}
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
if pinned_state_dicts and k not in pinned_state_dicts[param]:
pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
zero_state[param][k] = param_state
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
if pinned_state_dicts:
pinned_state_dicts[param][k].copy_(param_state)
zero_state[param][k] = pinned_state_dicts[param][k]
else:
zero_state[param][k] = param_state.cpu()
states_dict = self._pack_state(zero_state)
@ -822,7 +830,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.optim.load_state_dict(zero_state_dict)
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
def state_dict_shard(
self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Only include the 'state' in state_dict.
@ -847,18 +857,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
if pinned_state_dicts and param_idx not in pinned_state_dicts:
pinned_state_dicts[param_idx] = {}
master_param = idx2master[param_idx]
working_param = self.master_to_working_param[id(master_param)]
pg = self.param_to_pg[working_param]
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
pinned_state_dicts[param_idx][k] = torch.empty_like(
working_param, pin_memory=True, device="cpu"
)
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
if pinned_state_dicts:
pinned_state_dicts[param_idx][k].copy_(state_tensor)
current_block[k] = pinned_state_dicts[param_idx][k]
else:
current_block[k] = state_tensor.cpu()
current_block_size += state_tensor.numel()
current_block[k] = state_tensor
if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
yield ret_block, ret_block_size

@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
model_ckpt_path = f"{model_ckpt_path}.pt"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
if not shard and use_async:
optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
booster.save_model(
model,
model_ckpt_path,
@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
)
# lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
assert torch.equal(
working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())

@ -0,0 +1,127 @@
import tempfile
from copy import deepcopy
import torch
from colossalai.utils.safetensors import load_flat, save_nested
try:
from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
from colossalai.testing import check_state_dict_equal
def test_save_load():
with tempfile.TemporaryDirectory() as tempdir:
optimizer_state_dict = {
0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
}
# group_dict = {"param_groups": [0, 1, 2]}
group_dict = {
"param_groups": [
{
"lr": 0.001,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"bias_correction": True,
"params": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
],
}
]
}
metadata = deepcopy(group_dict)
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict, metadata)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict = load_flat(optimizer_saved_path)
state_dict = load_state_dict["state"]
group = {"param_groups": load_state_dict["param_groups"]}
check_state_dict_equal(optimizer_state_dict, state_dict)
check_state_dict_equal(group_dict, group)
model_state_dict = {
"module.weight0": torch.rand((1024, 1024)),
"module.weight1": torch.rand((1024, 1024)),
"module.weight2": torch.rand((1024, 1024)),
}
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, model_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict = load_flat(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)
Loading…
Cancel
Save