[async io] support async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
pull/6147/head
flybird11111 2024-11-18 17:52:24 +08:00 committed by Hongxin Liu
parent b90835bd32
commit eb69e640e5
15 changed files with 374 additions and 46 deletions

@@ -359,6 +359,7 @@ class Booster:
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ) -> None:
         """
         Save optimizer to checkpoint.
@@ -374,7 +375,9 @@ class Booster:
                 names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
+        self.checkpoint_io.save_optimizer(
+            optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
+        )

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Save lr scheduler to checkpoint.

@@ -94,7 +94,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -178,7 +180,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

     def save_sharded_optimizer(
-        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.

@@ -24,6 +24,7 @@ from colossalai.checkpoint_io.utils import (
     get_shard_filename,
     load_param_groups_into_optimizer,
     load_shard_state_dict,
+    load_state_dict,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -113,7 +114,9 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
+    ):
         """Save optimizer to checkpoint but only on master process.

         Args:
@@ -125,10 +128,35 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        state_dict = optimizer.state_dict()
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict = optimizer.state_dict(pinned_state_dicts)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from tensornvme.async_file_io import AsyncFileWriter
+
+                from colossalai.utils.safetensors import save_nested
+
+                f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+                save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+                self.async_writers.append(f_writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        use_async = checkpoint.endswith(".safetensors")
+        if use_async:
+            from colossalai.utils.safetensors import load_flat
+
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)

     def save_sharded_optimizer(
         self,
         optimizer: OptimizerWrapper,
@@ -136,6 +164,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         gather_dtensor: bool = False,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -161,10 +190,16 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)
@@ -184,6 +219,17 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
-                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    from colossalai.utils.safetensors import save_nested
+
+                    f_writer = AsyncFileWriter(
+                        fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+                    )
+                    save_nested(f_writer, shard)
+                    self.async_writers.append(f_writer)
+                else:
+                    save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

         # Wrap up index file.
@@ -223,6 +269,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

         for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                from colossalai.utils.safetensors import load_flat
+
+                state_dict = load_flat(shard_file)
+            else:
+                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             # shard state dict
             for param_idx, state in state_dict.items():
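
Taken together, the async branch above stages the optimizer state through cached pinned buffers and hands it to a background writer. A standalone sketch of that write path, assuming tensornvme is installed (the n_entries value and the explicit flush calls are copied from the unit test at the end of this commit):

    import torch
    from tensornvme.async_file_io import AsyncFileWriter
    from colossalai.utils.safetensors import save_nested

    state_dict = {"state": {0: {"exp_avg": torch.zeros(4)}}, "param_groups": [{"lr": 1e-3, "params": [0]}]}
    f_writer = AsyncFileWriter(fp=open("optimizer.safetensors", "wb"), n_entries=191, backend="pthread")
    save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
    # the plugin keeps the writer in self.async_writers and flushes it in _sync_io();
    # standalone, the same flush looks like this:
    f_writer.sync_before_step()
    f_writer.synchronize()
    f_writer.fp.close()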

@@ -52,7 +52,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -113,13 +115,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(
+                optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )

     def load_sharded_optimizer(
         self,

@@ -67,7 +67,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -157,7 +159,13 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint but only on master process.

@@ -213,6 +213,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
-
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ class CheckpointIO(ABC):
     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ class CheckpointIO(ABC):
         """

     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.

@@ -98,6 +98,7 @@ class GeneralCheckpointIO(CheckpointIO):
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ class GeneralCheckpointIO(CheckpointIO):
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

@@ -416,6 +416,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Update master params if mixed-precision training is enabled.
             model_before_wrapping.update_master_params()

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.

@@ -369,6 +369,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()

     # Copied from colossalai.moe
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_async: bool = False,
+    ):
         """
         Save optimizer state dict to a file with given path.

@@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"

 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file


-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)

-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)

     param_group_file = GROUP_FILE_NAME
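
For reference, the filenames this flag selects, in the return order used by the low-level-zero plugin above (assuming add_prefix leaves the names unchanged when prefix is None):

    # use_safetensors=False -> ("pytorch_optim.bin", "pytorch_optim.bin.index.json", "pytorch_optim_group.bin")
    # use_safetensors=True  -> ("optimizer.safetensors", "optimizer.safetensors.index.json", "pytorch_optim_group.bin")
    states_name, save_index_file, param_group_file = get_optimizer_base_filenames(None, use_safetensors=True)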

@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict
+from typing import Any, List, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
@@ -78,7 +78,9 @@ def check_state_dict_equal(
                     v1 = v1.to(v2.dtype)
                 assert_close_loose(v1, v2)
             else:
-                assert v1 == v2, f"{v1} not equals to {v2}"
+                if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
+                    v2 = tuple(v2)
+                assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
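
The coercion is needed because tuple values such as Adam's betas are stored as JSON metadata in the safetensors checkpoint and come back as lists; a small illustration:

    v1 = (0.9, 0.999)   # betas as kept in the live optimizer's param_groups
    v2 = [0.9, 0.999]   # the same value recovered from the checkpoint's JSON metadata
    if isinstance(v1, tuple) and not isinstance(v2, tuple):
        v2 = tuple(v2)
    assert v1 == v2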

@@ -1,10 +1,11 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
+import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
-from safetensors.torch import _TYPES
+from safetensors.torch import _TYPES, load_file, safe_open

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -27,34 +28,93 @@ class PreparedData:
     offset: int


-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+def flatten_dict(nested_dict, parent_key="", separator="^"):
+    """
+    Flatten a nested dictionary, joining nested keys with the specified separator.
+
+    nested_dict: the input nested dictionary.
+    parent_key: the parent key currently being processed.
+    separator: the separator used to join keys; defaults to '^' but can be customized.
+    :return: a flattened dictionary.
+    """
+    items = []
+    for k, v in nested_dict.items():
+        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, separator).items())
+        else:
+            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
+            items.append((new_key, v))
+
+    return dict(items)
+
+
+def unflatten_dict(flattened_dict, separator="^"):
+    """
+    Restore a flattened dictionary back to a multi-level nested dictionary.
+
+    flattened_dict: the flattened dictionary.
+    separator: the separator used during flattening; defaults to '^' but can be customized.
+    :return: the restored nested dictionary.
+    """
+    nested_dict = {}
+    for key, value in flattened_dict.items():
+        keys = key.split(separator)
+        try:
+            keys[0] = int(keys[0])
+        except ValueError:
+            warnings.warn(f"{keys[0]} can't be converted to an integer")
+        d = nested_dict
+        for part in keys[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        assert isinstance(value, torch.Tensor)
+        d[keys[-1]] = value
+
+    return nested_dict
+
+
+def prepare(
+    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+    if metadata is not None:
+        assert isinstance(metadata, dict)
+        for k, v in metadata.items():
+            metadata[k] = json.dumps(v)
+            assert isinstance(k, str)
+            assert isinstance(metadata[k], str)
+
     tensors = []
     tensor_keys = []
-    metadata = {}
+    header = {}
     offset = 0

+    if metadata is not None:
+        header["__metadata__"] = metadata
+
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
         )
         offset += n
-        metadata[name] = asdict(tensor_info)
+        header[name] = asdict(tensor_info)
         tensors.append(tensor)
         tensor_keys.append(name)

-    metadata_buf = json.dumps(metadata).encode("utf-8")
+    header_buf = json.dumps(header).encode("utf-8")

-    extra = (8 - len(metadata_buf) % 8) % 8
-    metadata_buf += b" " * extra
+    extra = (8 - len(header_buf) % 8) % 8
+    header_buf += b" " * extra

-    n = len(metadata_buf)
+    n = len(header_buf)

-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
+    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys


-def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors, _ = prepare(state_dict)
+def save(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -64,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


+def save_nested(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    flatten_data = flatten_dict(state_dict)
+    save(f_writer, flatten_data, metadata)
+
+
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
@@ -81,3 +148,16 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+
+
+def load_flat(checkpoint_path):
+    with safe_open(checkpoint_path, framework="pt") as f:
+        metadata = f.metadata()
+    state_dict_load = load_file(checkpoint_path)
+    state_dict = unflatten_dict(state_dict_load)
+    if metadata is None:
+        return state_dict
+    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
+    combined_state_dict = {"state": state_dict}
+    combined_state_dict.update(metadata)
+    return combined_state_dict
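
A round-trip sketch for the flatten/unflatten helpers added above: nested optimizer state is flattened into a single-level dict (keys joined with '^') before writing, and integer top-level keys are restored on load:

    import torch
    from colossalai.utils.safetensors import flatten_dict, unflatten_dict

    nested = {0: {"exp_avg": torch.zeros(4), "exp_avg_sq": torch.ones(4)}}
    flat = flatten_dict(nested)        # {"0^exp_avg": tensor, "0^exp_avg_sq": tensor}
    restored = unflatten_dict(flat)    # {0: {"exp_avg": tensor, "exp_avg_sq": tensor}}
    assert torch.equal(restored[0]["exp_avg"], nested[0]["exp_avg"])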

@@ -770,7 +770,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         return {"state": packed_state, "param_groups": param_groups}

-    def state_dict(self) -> Dict:
+    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
         """Return a state_dict same with DDP

         Returns:
@@ -779,15 +779,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            if pinned_state_dicts and param not in pinned_state_dicts:
+                pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
-                    zero_state[param][k] = param_state
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param][k].copy_(param_state)
+                        zero_state[param][k] = pinned_state_dicts[param][k]
+                    else:
+                        zero_state[param][k] = param_state.cpu()

         states_dict = self._pack_state(zero_state)
@@ -822,7 +830,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         self.optim.load_state_dict(zero_state_dict)

-    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+    def state_dict_shard(
+        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+    ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.

         Only include the 'state' in state_dict.
@@ -847,18 +857,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
+            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+                pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(
+                            working_param, pin_memory=True, device="cpu"
+                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
+                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                        current_block[k] = pinned_state_dicts[param_idx][k]
+                    else:
+                        current_block[k] = state_tensor.cpu()
                     current_block_size += state_tensor.numel()
-                    current_block[k] = state_tensor

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
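
On the caller side, the new pinned_state_dicts argument lets repeated saves reuse the same pinned CPU staging buffers instead of allocating fresh pageable memory each time. A minimal sketch (zero_optimizer stands for an already-boosted LowLevelZeroOptimizer and is assumed, not shown):

    pinned = {}  # param_idx -> {state name -> pinned CPU tensor}, kept alive across saves
    for shard, shard_size in zero_optimizer.state_dict_shard(max_shard_size=1024, pinned_state_dicts=pinned):
        pass  # hand each shard to an async writer, as save_sharded_optimizer does earlier in this commit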

@@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             model_ckpt_path = f"{model_ckpt_path}.pt"
         if not shard and use_async:
             model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        if not shard and use_async:
+            optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
         booster.save_model(
             model,
             model_ckpt_path,
@@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         )

         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
         booster.checkpoint_io._sync_d2h()
         booster.checkpoint_io._sync_io()
         dist.barrier()
@@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
             assert torch.equal(
                 working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
             )
-
         new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())

@@ -0,0 +1,127 @@
+import tempfile
+from copy import deepcopy
+
+import torch
+
+from colossalai.utils.safetensors import load_flat, save_nested
+
+try:
+    from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
+
+from colossalai.testing import check_state_dict_equal
+
+
+def test_save_load():
+    with tempfile.TemporaryDirectory() as tempdir:
+        optimizer_state_dict = {
+            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+        }
+        # group_dict = {"param_groups": [0, 1, 2]}
+        group_dict = {
+            "param_groups": [
+                {
+                    "lr": 0.001,
+                    "betas": (0.9, 0.999),
+                    "eps": 1e-08,
+                    "weight_decay": 0,
+                    "bias_correction": True,
+                    "params": [
+                        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+                        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+                        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+                    ],
+                }
+            ]
+        }
+        metadata = deepcopy(group_dict)
+        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict, metadata)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+
+        load_state_dict = load_flat(optimizer_saved_path)
+        state_dict = load_state_dict["state"]
+        group = {"param_groups": load_state_dict["param_groups"]}
+        check_state_dict_equal(optimizer_state_dict, state_dict)
+        check_state_dict_equal(group_dict, group)
+
+        model_state_dict = {
+            "module.weight0": torch.rand((1024, 1024)),
+            "module.weight1": torch.rand((1024, 1024)),
+            "module.weight2": torch.rand((1024, 1024)),
+        }
+        model_saved_path = f"{tempdir}/save_model.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, model_state_dict)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+
+        load_state_dict = load_flat(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)