mirror of https://github.com/hpcaitech/ColossalAI

[async io] support async io (#6137)

* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix

pull/6147/head
parent b90835bd32
commit eb69e640e5
@@ -359,6 +359,7 @@ class Booster:
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ) -> None:
         """
         Save optimizer to checkpoint.
@@ -374,7 +375,9 @@ class Booster:
                 names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
+        self.checkpoint_io.save_optimizer(
+            optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
+        )

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Save lr scheduler to checkpoint.
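A minimal usage sketch of the new flag through the Booster API (the `booster`, `model`, and `optimizer` names and the checkpoint path are illustrative; the sync calls mirror the test changes further down in this commit):

    # Hedged sketch: kick off an asynchronous optimizer save, then flush pending
    # device-to-host copies and file writes before relying on the checkpoint.
    booster.save_optimizer(optimizer, "ckpt/optimizer.safetensors", shard=False, use_async=True)
    booster.checkpoint_io._sync_d2h()
    booster.checkpoint_io._sync_io()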
@@ -94,7 +94,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -178,7 +180,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

     def save_sharded_optimizer(
-        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
@@ -24,6 +24,7 @@ from colossalai.checkpoint_io.utils import (
     get_shard_filename,
     load_param_groups_into_optimizer,
     load_shard_state_dict,
+    load_state_dict,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -113,7 +114,9 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
+    ):
         """Save optimizer to checkpoint but only on master process.

         Args:
@@ -125,10 +128,35 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        state_dict = optimizer.state_dict()
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict = optimizer.state_dict(pinned_state_dicts)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from tensornvme.async_file_io import AsyncFileWriter
+
+                from colossalai.utils.safetensors import save_nested
+
+                f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+                save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+                self.async_writers.append(f_writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        use_async = checkpoint.endswith(".safetensors")
+        if use_async:
+            from colossalai.utils.safetensors import load_flat
+
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)

     def save_sharded_optimizer(
         self,
         optimizer: OptimizerWrapper,
@@ -136,6 +164,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         gather_dtensor: bool = False,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -161,10 +190,16 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)

@@ -184,6 +219,17 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
-                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    from colossalai.utils.safetensors import save_nested
+
+                    f_writer = AsyncFileWriter(
+                        fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+                    )
+                    save_nested(f_writer, shard)
+                    self.async_writers.append(f_writer)
+                else:
+                    save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

         # Wrap up index file.
@@ -223,6 +269,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

         for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                from colossalai.utils.safetensors import load_flat
+
+                state_dict = load_flat(shard_file)
+            else:
+                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             # shard state dict
             for param_idx, state in state_dict.items():
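Worth noting for callers (a hedged sketch, since `load_unsharded_optimizer` above dispatches on the file extension): an async unsharded optimizer checkpoint has to be written to a `.safetensors` path so that reload takes the `load_flat` branch. The path and variable names here are illustrative:

    # Hedged sketch: match the extension to the save mode so reload picks the right branch.
    optimizer_ckpt_path = "ckpt/optimizer.safetensors" if use_async else "ckpt/optimizer.pt"
    booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False, use_async=use_async)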
@@ -52,7 +52,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -113,13 +115,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(
+                optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )

     def load_sharded_optimizer(
         self,
@@ -67,7 +67,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -157,7 +159,13 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint but only on master process.
@@ -213,6 +213,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """

         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ class CheckpointIO(ABC):

     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ class CheckpointIO(ABC):
         """

     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.
@@ -98,6 +98,7 @@ class GeneralCheckpointIO(CheckpointIO):
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ class GeneralCheckpointIO(CheckpointIO):
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
@@ -416,6 +416,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.
@@ -369,6 +369,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()

     # Copied from colossalai.moe
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_async: bool = False,
+    ):
         """
         Save optimizer state dict to a file with given path.
@@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"

 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file


-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)

-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)

     param_group_file = GROUP_FILE_NAME
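For illustration, a hedged sketch of the filenames this helper now selects when no prefix is given (assuming `add_prefix` leaves the name unchanged for a None prefix; based on the constants added above):

    # Hedged sketch: safetensors vs. legacy optimizer state filenames.
    get_optimizer_base_filenames(use_safetensors=True)
    # -> ("optimizer.safetensors", "optimizer.safetensors.index.json", "pytorch_optim_group.bin")
    get_optimizer_base_filenames(use_safetensors=False)
    # -> ("pytorch_optim.bin", "pytorch_optim.bin.index.json", "pytorch_optim_group.bin")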
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict
+from typing import Any, List, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
@@ -78,7 +78,9 @@ def check_state_dict_equal(
                 v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
-            assert v1 == v2, f"{v1} not equals to {v2}"
+            if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
+                v2 = tuple(v2)
+            assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
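A hedged illustration of why the tuple coercion is needed: hyper-parameters such as Adam's betas are tuples in the live optimizer, but they come back as lists after the JSON metadata round trip used by save_nested/load_flat below. The concrete values here are illustrative, not taken from the diff:

    # Hedged sketch of the tuple/list mismatch the new check tolerates.
    v1 = (0.9, 0.999)   # "betas" as stored in optimizer.param_groups
    v2 = [0.9, 0.999]   # same value after json.dumps / json.loads in the safetensors metadata
    assert isinstance(v1, tuple) and not isinstance(v2, tuple)
    assert v1 == tuple(v2)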
@@ -1,10 +1,11 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
+import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
-from safetensors.torch import _TYPES
+from safetensors.torch import _TYPES, load_file, safe_open

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -27,34 +28,93 @@ class PreparedData:
     offset: int


-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+def flatten_dict(nested_dict, parent_key="", separator="^"):
+    """
+    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
+
+    nested_dict: The input nested dictionary.
+    parent_key: The parent key currently being processed.
+    separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary.
+    """
+    items = []
+    for k, v in nested_dict.items():
+        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, separator).items())
+        else:
+            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
+            items.append((new_key, v))
+
+    return dict(items)
+
+
+def unflatten_dict(flattened_dict, separator="^"):
+    """
+    Restore a flattened dictionary back to a multi-level nested dictionary.
+
+    flattened_dict: The flattened dictionary.
+    separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
+    """
+    nested_dict = {}
+    for key, value in flattened_dict.items():
+        keys = key.split(separator)
+        try:
+            keys[0] = int(keys[0])
+        except ValueError:
+            warnings.warn(f"{key[0]} can't convert to integer")
+        d = nested_dict
+        for part in keys[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        assert isinstance(value, torch.Tensor)
+        d[keys[-1]] = value
+
+    return nested_dict
+
+
+def prepare(
+    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+    if metadata is not None:
+        assert isinstance(metadata, dict)
+        for k, v in metadata.items():
+            metadata[k] = json.dumps(v)
+            assert isinstance(k, str)
+            assert isinstance(metadata[k], str)
+
     tensors = []
     tensor_keys = []
-    metadata = {}
+    header = {}
     offset = 0

+    if metadata is not None:
+        header["__metadata__"] = metadata
+
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
         )
         offset += n
-        metadata[name] = asdict(tensor_info)
+        header[name] = asdict(tensor_info)
         tensors.append(tensor)
         tensor_keys.append(name)

-    metadata_buf = json.dumps(metadata).encode("utf-8")
+    header_buf = json.dumps(header).encode("utf-8")

-    extra = (8 - len(metadata_buf) % 8) % 8
-    metadata_buf += b" " * extra
+    extra = (8 - len(header_buf) % 8) % 8
+    header_buf += b" " * extra

-    n = len(metadata_buf)
+    n = len(header_buf)

-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
+    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys


-def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors, _ = prepare(state_dict)
+def save(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -64,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


+def save_nested(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    flatten_data = flatten_dict(state_dict)
+    save(f_writer, flatten_data, metadata)
+
+
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
@@ -81,3 +148,16 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+
+
+def load_flat(checkpoint_path):
+    with safe_open(checkpoint_path, framework="pt") as f:
+        metadata = f.metadata()
+    state_dict_load = load_file(checkpoint_path)
+    state_dict = unflatten_dict(state_dict_load)
+    if metadata is None:
+        return state_dict
+    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
+    combined_state_dict = {"state": state_dict}
+    combined_state_dict.update(metadata)
+    return combined_state_dict
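A hedged sketch of how these helpers fit together for an optimizer state dict (illustrative shapes, keys, and output path; the writer setup and sync calls follow the new test at the bottom of this commit):

    # Hedged sketch: nested optimizer state is flattened with '^'-joined keys,
    # written asynchronously, then restored to nested form on load.
    nested = {0: {"exp_avg": torch.rand(4, 4), "exp_avg_sq": torch.rand(4, 4)}}
    flat = flatten_dict(nested)            # {"0^exp_avg": ..., "0^exp_avg_sq": ...}
    assert unflatten_dict(flat).keys() == nested.keys()

    f_writer = AsyncFileWriter(fp=open("optimizer.safetensors", "wb"), n_entries=191, backend="pthread")
    save_nested(f_writer, nested, {"param_groups": [{"lr": 1e-3, "params": [0]}]})
    f_writer.sync_before_step()
    f_writer.synchronize()
    f_writer.fp.close()

    loaded = load_flat("optimizer.safetensors")   # {"state": {...}, "param_groups": [...]}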
@@ -770,7 +770,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         return {"state": packed_state, "param_groups": param_groups}

-    def state_dict(self) -> Dict:
+    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
         """Return a state_dict same with DDP

         Returns:
@@ -779,15 +779,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            if pinned_state_dicts and param not in pinned_state_dicts:
+                pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
-                    zero_state[param][k] = param_state
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param][k].copy_(param_state)
+                        zero_state[param][k] = pinned_state_dicts[param][k]
+                    else:
+                        zero_state[param][k] = param_state.cpu()

         states_dict = self._pack_state(zero_state)

@@ -822,7 +830,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         self.optim.load_state_dict(zero_state_dict)

-    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+    def state_dict_shard(
+        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+    ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
         Only include the 'state' in state_dict.
@@ -847,18 +857,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
+            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+                pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(
+                            working_param, pin_memory=True, device="cpu"
+                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
+                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                        current_block[k] = pinned_state_dicts[param_idx][k]
+                    else:
+                        current_block[k] = state_tensor.cpu()
                     current_block_size += state_tensor.numel()
-                    current_block[k] = state_tensor

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
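A hedged sketch of how a caller is expected to use the new `pinned_state_dicts` argument (the plugin above keeps one such dict per optimizer, keyed by `id(optimizer)`; the variable names here are illustrative):

    # Hedged sketch: the pinned dict persists across saves, so pinned CPU buffers
    # are allocated once and reused by every subsequent async checkpoint.
    pinned_state_dicts = {}                               # kept alive between calls
    state = optimizer.state_dict(pinned_state_dicts)      # first call allocates pin_memory buffers
    state = optimizer.state_dict(pinned_state_dicts)      # later calls copy into the same buffers
    for shard, size in optimizer.state_dict_shard(max_shard_size=1024, pinned_state_dicts=pinned_state_dicts):
        ...  # each shard's tensors already live in pinned host memory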
@@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         model_ckpt_path = f"{model_ckpt_path}.pt"
         if not shard and use_async:
             model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        if not shard and use_async:
+            optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
         booster.save_model(
             model,
             model_ckpt_path,
@@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         )

         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
         booster.checkpoint_io._sync_d2h()
         booster.checkpoint_io._sync_io()
         dist.barrier()
@@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
         assert torch.equal(
             working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
         )

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
@@ -0,0 +1,127 @@
+import tempfile
+from copy import deepcopy
+
+import torch
+
+from colossalai.utils.safetensors import load_flat, save_nested
+
+try:
+    from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
+
+from colossalai.testing import check_state_dict_equal
+
+
+def test_save_load():
+    with tempfile.TemporaryDirectory() as tempdir:
+        optimizer_state_dict = {
+            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+        }
+        # group_dict = {"param_groups": [0, 1, 2]}
+        group_dict = {
+            "param_groups": [
+                {
+                    "lr": 0.001,
+                    "betas": (0.9, 0.999),
+                    "eps": 1e-08,
+                    "weight_decay": 0,
+                    "bias_correction": True,
+                    "params": [
+                        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+                        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+                        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+                    ],
+                }
+            ]
+        }
+        metadata = deepcopy(group_dict)
+        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
+
+        save_nested(f_writer, optimizer_state_dict, metadata)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+
+        load_state_dict = load_flat(optimizer_saved_path)
+        state_dict = load_state_dict["state"]
+        group = {"param_groups": load_state_dict["param_groups"]}
+        check_state_dict_equal(optimizer_state_dict, state_dict)
+        check_state_dict_equal(group_dict, group)
+
+        model_state_dict = {
+            "module.weight0": torch.rand((1024, 1024)),
+            "module.weight1": torch.rand((1024, 1024)),
+            "module.weight2": torch.rand((1024, 1024)),
+        }
+        model_saved_path = f"{tempdir}/save_model.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, model_state_dict)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+
+        load_state_dict = load_flat(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)