mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] General Checkpointing of Sharded Optimizers (#3984)
parent
8bcad73677
commit
c9cff7e7fa
|
@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||||
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
|
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||||
model: GeminiDDP,
|
model: GeminiDDP,
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
gather_dtensor: bool = False,
|
gather_dtensor: bool = False,
|
||||||
variant: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
max_shard_size: int = 1024,
|
max_shard_size: int = 1024,
|
||||||
use_safetensors: bool = False):
|
use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
Save sharded model
|
Save sharded model
|
||||||
"""
|
"""
|
||||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
|
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
|
||||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||||
total_size = 0
|
total_size = 0
|
||||||
index_file = CheckpointIndexFile(checkpoint_path)
|
index_file = CheckpointIndexFile(checkpoint_path)
|
||||||
for idx, shard_pair in enumerate(state_dict_shard):
|
for idx, shard_pair in enumerate(state_dict_shard):
|
||||||
|
|
|
@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||||
"""
|
"""
|
||||||
Save model to checkpoint but only on master process.
|
Save model to checkpoint but only on master process.
|
||||||
"""
|
"""
|
||||||
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||||
|
|
||||||
|
@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
gather_dtensor: bool = False,
|
gather_dtensor: bool = False,
|
||||||
variant: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
max_shard_size: int = 1024,
|
max_shard_size: int = 1024,
|
||||||
use_safetensors: bool = False):
|
use_safetensors: bool = False):
|
||||||
|
"""
|
||||||
|
Save model to checkpoint but only on master process.
|
||||||
|
"""
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
|
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
|
||||||
|
|
||||||
|
def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
|
||||||
|
size_per_shard: int):
|
||||||
|
"""
|
||||||
|
Save optimizer to checkpoint but only on master process.
|
||||||
|
"""
|
||||||
|
if self.coordinator.is_master():
|
||||||
|
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
|
||||||
|
|
||||||
|
|
||||||
class TorchDDPModel(ModelWrapper):
|
class TorchDDPModel(ModelWrapper):
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
|
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import warnings
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
||||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||||
|
|
||||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||||
size_per_shard: int, use_safetensors: bool):
|
size_per_shard: int, use_safetensors: bool):
|
||||||
"""
|
"""
|
||||||
Save model to checkpoint but only on master process.
|
Save model to checkpoint but only on master process.
|
||||||
|
@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||||
|
|
||||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
|
||||||
|
size_per_shard: int):
|
||||||
"""
|
"""
|
||||||
Save optimizer to checkpoint but only on master process.
|
Save optimizer to checkpoint but only on master process.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||||
|
|
||||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||||
"""
|
"""
|
||||||
Load optimizer to checkpoint but only on master process.
|
Load optimizer to checkpoint but only on master process.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -103,7 +103,7 @@ class CheckpointIO(ABC):
|
||||||
checkpoint: str,
|
checkpoint: str,
|
||||||
shard: bool = False,
|
shard: bool = False,
|
||||||
gather_dtensor: bool = True,
|
gather_dtensor: bool = True,
|
||||||
variant: str = None,
|
prefix: str = None,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
use_safetensors: bool = False):
|
use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
|
@ -128,7 +128,7 @@ class CheckpointIO(ABC):
|
||||||
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
|
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
|
||||||
that the checkpoint path is a directory path instead of a file path.
|
that the checkpoint path is a directory path instead of a file path.
|
||||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
||||||
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
|
prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
|
||||||
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
|
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
|
||||||
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
|
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
|
||||||
"""
|
"""
|
||||||
|
@ -137,11 +137,11 @@ class CheckpointIO(ABC):
|
||||||
model = model.unwrap()
|
model = model.unwrap()
|
||||||
|
|
||||||
if shard:
|
if shard:
|
||||||
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
|
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
|
||||||
else:
|
else:
|
||||||
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||||
|
|
||||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
|
||||||
"""
|
"""
|
||||||
Load optimizer from checkpoint.
|
Load optimizer from checkpoint.
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ class CheckpointIO(ABC):
|
||||||
|
|
||||||
if index_file_exists:
|
if index_file_exists:
|
||||||
# the existence of index file means it is a sharded checkpoint
|
# the existence of index file means it is a sharded checkpoint
|
||||||
self.load_sharded_optimizer(optimizer, index_file_path)
|
self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
|
||||||
else:
|
else:
|
||||||
self.load_unsharded_optimizer(optimizer, checkpoint)
|
self.load_unsharded_optimizer(optimizer, checkpoint)
|
||||||
|
|
||||||
|
@ -218,7 +218,7 @@ class CheckpointIO(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||||
size_per_shard: int, use_safetensors: bool):
|
size_per_shard: int, use_safetensors: bool):
|
||||||
"""
|
"""
|
||||||
Save model to sharded checkpoint.
|
Save model to sharded checkpoint.
|
||||||
|
|
|
@ -11,15 +11,21 @@ from torch.optim import Optimizer
|
||||||
from .checkpoint_io_base import CheckpointIO
|
from .checkpoint_io_base import CheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
from .utils import (
|
from .utils import (
|
||||||
get_base_filenames,
|
get_model_base_filenames,
|
||||||
|
get_optimizer_base_filenames,
|
||||||
get_shard_filename,
|
get_shard_filename,
|
||||||
has_index_file,
|
has_index_file,
|
||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
|
load_param_groups_into_optimizer,
|
||||||
load_shard_state_dict,
|
load_shard_state_dict,
|
||||||
load_state_dict,
|
load_state_dict,
|
||||||
load_state_dict_into_model,
|
load_state_dict_into_model,
|
||||||
|
load_states_into_optimizer,
|
||||||
|
save_param_groups,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
shard_checkpoint,
|
shard_model_checkpoint,
|
||||||
|
shard_optimizer_checkpoint,
|
||||||
|
sharded_optimizer_loading_epilogue,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = ['GeneralCheckpointIO']
|
__all__ = ['GeneralCheckpointIO']
|
||||||
|
@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
# save the checkpoint
|
# save the checkpoint
|
||||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
"""
|
||||||
|
Load sharded optimizer with the given path to index file.
|
||||||
|
"""
|
||||||
|
optimizer.load_state_dict
|
||||||
|
# Read checkpoint index file.
|
||||||
|
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||||
|
|
||||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
# Load param_groups
|
||||||
checkpoint = load_state_dict(checkpoint)
|
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||||
optimizer.load_state_dict(checkpoint)
|
if param_group_path is None:
|
||||||
|
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
|
||||||
|
Lacking param group file under current directory.')
|
||||||
|
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
|
||||||
|
|
||||||
|
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||||
|
|
||||||
|
for shard_file in checkpoint_files:
|
||||||
|
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||||
|
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||||
|
del state_dict
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
sharded_optimizer_loading_epilogue(optimizer)
|
||||||
|
|
||||||
def save_sharded_optimizer(
|
def save_sharded_optimizer(
|
||||||
self,
|
self,
|
||||||
|
@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
prefix: str,
|
prefix: str,
|
||||||
size_per_shard: int,
|
size_per_shard: int,
|
||||||
):
|
):
|
||||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
"""
|
||||||
|
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||||
|
The following files will be created under the path:
|
||||||
|
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||||
|
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||||
|
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
|
||||||
|
"""
|
||||||
|
if os.path.isfile(checkpoint):
|
||||||
|
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||||
|
return
|
||||||
|
|
||||||
|
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Offload optimizer states. States are broken into shards within max_shard_size.
|
||||||
|
state_dict = optimizer.state_dict()
|
||||||
|
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||||
|
|
||||||
|
# Preparing file paths and index file.
|
||||||
|
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||||
|
index_file = CheckpointIndexFile(checkpoint)
|
||||||
|
|
||||||
|
# Store the information of param groups to param_group_file.
|
||||||
|
index_file.append_meta_data("param_groups", param_group_file)
|
||||||
|
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||||
|
save_param_groups(state_dict, group_file_path)
|
||||||
|
|
||||||
|
# Save shards of optimizer states.
|
||||||
|
total_size = 0
|
||||||
|
for idx, shard_pair in enumerate(sharded_state):
|
||||||
|
shard, current_size = shard_pair
|
||||||
|
shard_file = get_shard_filename(states_name, idx)
|
||||||
|
total_size = total_size + current_size
|
||||||
|
for param_id in shard.keys():
|
||||||
|
index_file.append_weight_map(str(param_id), shard_file)
|
||||||
|
|
||||||
|
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||||
|
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
||||||
|
|
||||||
|
# Wrap up index file.
|
||||||
|
index_file.append_meta_data("total_size", total_size)
|
||||||
|
index_file.write_index_file(save_index_file)
|
||||||
|
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||||
|
f"You can find where each parameters has been saved in the "
|
||||||
|
f"index located at {save_index_file}.")
|
||||||
|
|
||||||
|
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||||
|
checkpoint = load_state_dict(checkpoint)
|
||||||
|
optimizer.load_state_dict(checkpoint)
|
||||||
|
|
||||||
def save_unsharded_optimizer(
|
def save_unsharded_optimizer(
|
||||||
self,
|
self,
|
||||||
|
@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
gather_dtensor: bool = False,
|
gather_dtensor: bool = False,
|
||||||
variant: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
max_shard_size: int = 1024,
|
max_shard_size: int = 1024,
|
||||||
use_safetensors: bool = False):
|
use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
|
@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
|
|
||||||
# shard checkpoint
|
# shard checkpoint
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
|
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||||
|
|
||||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||||
total_size = 0
|
total_size = 0
|
||||||
index_file = CheckpointIndexFile(checkpoint_path)
|
index_file = CheckpointIndexFile(checkpoint_path)
|
||||||
for idx, shard_pair in enumerate(state_dict_shard):
|
for idx, shard_pair in enumerate(state_dict_shard):
|
||||||
|
@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||||
|
|
||||||
# read checkpoint index file
|
# read checkpoint index file
|
||||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
|
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||||
missing_keys = []
|
missing_keys = []
|
||||||
|
|
||||||
for shard_file in checkpoint_files:
|
for shard_file in checkpoint_files:
|
||||||
|
|
|
@ -111,7 +111,7 @@ class CheckpointIndexFile:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_checkpoint_fileanames(self) -> List[str]:
|
def get_checkpoint_filenames(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the set of checkpoint filenames in the weight map.
|
Get the set of checkpoint filenames in the weight map.
|
||||||
|
|
||||||
|
@ -159,6 +159,18 @@ class CheckpointIndexFile:
|
||||||
"""
|
"""
|
||||||
return list(self.weight_map.keys())
|
return list(self.weight_map.keys())
|
||||||
|
|
||||||
|
def get_param_group_filename(self) -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Get the file name of param_group file if this is a checkpoint for optimizer.
|
||||||
|
Returns:
|
||||||
|
str: param_group file name
|
||||||
|
"""
|
||||||
|
filename = self.metadata.get("param_groups", None)
|
||||||
|
if filename:
|
||||||
|
return str(self.root_path.joinpath(filename))
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def write_index_file(self, save_index_file):
|
def write_index_file(self, save_index_file):
|
||||||
"""
|
"""
|
||||||
Write index file.
|
Write index file.
|
||||||
|
|
|
@ -1,17 +1,24 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
import re
|
import re
|
||||||
|
from collections import abc as container_abcs
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||||
|
|
||||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
STATES_NAME = "pytorch_optim.bin"
|
||||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||||
|
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
|
||||||
|
GROUP_FILE_NAME = "pytorch_optim_group.bin"
|
||||||
|
|
||||||
# ======================================
|
# ======================================
|
||||||
# General helper functions
|
# General helper functions
|
||||||
|
@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
||||||
# ======================================
|
# ======================================
|
||||||
# Helper functions for saving shard file
|
# Helper functions for saving shard file
|
||||||
# ======================================
|
# ======================================
|
||||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||||
"""
|
"""
|
||||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||||
given size.
|
given size.
|
||||||
|
@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
|
||||||
yield current_block, current_block_size
|
yield current_block, current_block_size
|
||||||
|
|
||||||
|
|
||||||
|
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||||
|
"""
|
||||||
|
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||||
|
given size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
|
||||||
|
states = state_dict['state']
|
||||||
|
|
||||||
|
current_block = {}
|
||||||
|
current_block_size = 0
|
||||||
|
|
||||||
|
for param_id, state in states.items():
|
||||||
|
|
||||||
|
ret_block = None
|
||||||
|
ret_block_size = 0
|
||||||
|
|
||||||
|
# A state might contain more than one tensors.
|
||||||
|
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||||
|
state_size = 0
|
||||||
|
isDTensor = False
|
||||||
|
for state_tensor in state.values():
|
||||||
|
# If the states are stored as DTensors, mark isDTensor as true.
|
||||||
|
if type(state_tensor) == DTensor:
|
||||||
|
isDTensor = True
|
||||||
|
state_size += calculate_tensor_size(state_tensor)
|
||||||
|
|
||||||
|
if not isDTensor:
|
||||||
|
|
||||||
|
if current_block_size + state_size > max_shard_size:
|
||||||
|
ret_block = current_block
|
||||||
|
ret_block_size = current_block_size
|
||||||
|
current_block = {}
|
||||||
|
current_block_size = 0
|
||||||
|
|
||||||
|
current_block[param_id] = state
|
||||||
|
current_block_size += state_size
|
||||||
|
|
||||||
|
if ret_block != None:
|
||||||
|
yield ret_block, ret_block_size
|
||||||
|
|
||||||
|
yield current_block, current_block_size
|
||||||
|
|
||||||
|
|
||||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
load shard state dict into model
|
load shard state dict into model
|
||||||
|
@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module,
|
||||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||||
|
|
||||||
|
|
||||||
|
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
Load information of param_groups into an initialized optimizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Load list of param_groups from given file path.
|
||||||
|
# The params in saved_groups are in the form of integer indices.
|
||||||
|
saved_groups = torch.load(param_group_path)
|
||||||
|
if not isinstance(saved_groups, List):
|
||||||
|
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||||
|
|
||||||
|
# The params in param_groups are in the form of pytorch tensors.
|
||||||
|
# For more details, please view source code of Optimizer class in pytorch.
|
||||||
|
param_groups = optimizer.param_groups
|
||||||
|
|
||||||
|
# Check the compatibility of saved_groups and param_groups.
|
||||||
|
if len(param_groups) != len(saved_groups):
|
||||||
|
raise ValueError("loaded state dict has a different number of original parameter groups")
|
||||||
|
param_lens = (len(g['params']) for g in param_groups)
|
||||||
|
saved_lens = (len(g['params']) for g in saved_groups)
|
||||||
|
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||||
|
raise ValueError("loaded state dict contains a parameter group "
|
||||||
|
"that doesn't match the size of optimizer's group")
|
||||||
|
|
||||||
|
# Creating mapping from id to parameters.
|
||||||
|
id_map = {
|
||||||
|
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||||
|
)), chain.from_iterable((g['params'] for g in param_groups)))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update parameter groups, setting their 'params' value.
|
||||||
|
def update_group(group, new_group):
|
||||||
|
new_group['params'] = group['params']
|
||||||
|
return new_group
|
||||||
|
|
||||||
|
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
|
||||||
|
|
||||||
|
optimizer.__dict__.update({'param_groups': updated_groups})
|
||||||
|
return id_map
|
||||||
|
|
||||||
|
|
||||||
|
def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
|
||||||
|
r"""Copies states from `state_dict` into an Optimizer object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer(Optimizer): An initialized Optimizer object to be loaded
|
||||||
|
state_dict(dict): a mapping from tensor index (an integer)
|
||||||
|
to its states to be loaded (a mapping from state name to a tensor).
|
||||||
|
id_map(dict): a mapping from tensor index (an integer)
|
||||||
|
to its corresponding parameter (a tensor) whose states will be updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def cast(param, value, key=None):
|
||||||
|
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
# Floating-point types are a bit special here. They are the only ones
|
||||||
|
# that are assumed to always match the type of params.
|
||||||
|
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
|
||||||
|
if (key != "step"):
|
||||||
|
if param.is_floating_point():
|
||||||
|
value = value.to(param.dtype)
|
||||||
|
value = value.to(param.device)
|
||||||
|
return value
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||||
|
elif isinstance(value, container_abcs.Iterable):
|
||||||
|
return type(value)(cast(param, v) for v in value)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||||
|
# State that is not assigned to params is copied as is (needed for
|
||||||
|
# backward compatibility).
|
||||||
|
new_states = defaultdict(dict)
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if k in id_map:
|
||||||
|
param = id_map[k]
|
||||||
|
new_states[param] = cast(param, v)
|
||||||
|
else:
|
||||||
|
new_states[k] = v
|
||||||
|
|
||||||
|
optimzier.state.update(new_states)
|
||||||
|
|
||||||
|
|
||||||
|
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
|
||||||
|
# Do the cleaning up as in src code of Pytorch.
|
||||||
|
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||||
|
optimizer.defaults.setdefault('differentiable', False)
|
||||||
|
|
||||||
|
|
||||||
# ======================================
|
# ======================================
|
||||||
# Helper functions for saving state dict
|
# Helper functions for saving state dict
|
||||||
# ======================================
|
# ======================================
|
||||||
|
@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
|
||||||
torch.save(state_dict, checkpoint_file_path)
|
torch.save(state_dict, checkpoint_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Save information of param_groups to given file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict): state dict.
|
||||||
|
group_file_path (str): path to the group file.
|
||||||
|
"""
|
||||||
|
param_groups = state_dict["param_groups"]
|
||||||
|
torch.save(param_groups, group_file_path)
|
||||||
|
|
||||||
|
|
||||||
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
|
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
|
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
|
||||||
|
@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path):
|
||||||
return torch.load(checkpoint_file_path)
|
return torch.load(checkpoint_file_path)
|
||||||
|
|
||||||
|
|
||||||
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
|
||||||
if variant is not None and len(variant) > 0:
|
if prefix is not None and len(prefix) > 0:
|
||||||
splits = weights_name.split(".")
|
splits = weights_name.split(".")
|
||||||
splits = splits[:-1] + [variant] + splits[-1:]
|
splits = splits[:-1] + [prefix] + splits[-1:]
|
||||||
weights_name = ".".join(splits)
|
weights_name = ".".join(splits)
|
||||||
|
|
||||||
return weights_name
|
return weights_name
|
||||||
|
|
||||||
|
|
||||||
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
|
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
generate base weight filenames
|
generate base model weight filenames
|
||||||
"""
|
"""
|
||||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||||
weights_name = add_variant(weights_name, variant)
|
weights_name = add_prefix(weights_name, prefix)
|
||||||
|
|
||||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||||
save_index_file = add_variant(save_index_file, variant)
|
save_index_file = add_prefix(save_index_file, prefix)
|
||||||
|
|
||||||
return weights_name, save_index_file
|
return weights_name, save_index_file
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer_base_filenames(prefix: str = None):
|
||||||
|
"""
|
||||||
|
generate base optimizer state filenames
|
||||||
|
"""
|
||||||
|
states_name = STATES_NAME
|
||||||
|
states_name = add_prefix(states_name, prefix)
|
||||||
|
|
||||||
|
save_index_file = STATES_INDEX_NAME
|
||||||
|
save_index_file = add_prefix(save_index_file, prefix)
|
||||||
|
|
||||||
|
param_group_file = GROUP_FILE_NAME
|
||||||
|
param_group_file = add_prefix(param_group_file, prefix)
|
||||||
|
|
||||||
|
return states_name, save_index_file, param_group_file
|
||||||
|
|
||||||
|
|
||||||
def get_shard_filename(weights_name: str, idx: int):
|
def get_shard_filename(weights_name: str, idx: int):
|
||||||
"""
|
"""
|
||||||
get shard file name
|
get shard file name
|
||||||
|
|
|
@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('use_safetensors', [True, False])
|
@pytest.mark.parametrize('use_safetensors', [True, False])
|
||||||
def test_sharded_checkpoint(use_safetensors: bool):
|
def test_sharded_model_checkpoint(use_safetensors: bool):
|
||||||
# create a model and optimizer
|
# create a model and optimizer
|
||||||
model = resnet18()
|
model = resnet18()
|
||||||
optimizer = Adam(model.parameters(), lr=0.001)
|
optimizer = Adam(model.parameters(), lr=0.001)
|
||||||
|
@ -100,3 +100,101 @@ def test_sharded_checkpoint(use_safetensors: bool):
|
||||||
# check for model and optimizer state dict recursively
|
# check for model and optimizer state dict recursively
|
||||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
def test_sharded_optimizer_checkpoint():
|
||||||
|
|
||||||
|
# create a model and optimizer
|
||||||
|
model = resnet18()
|
||||||
|
optimizer = Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# create test data sample
|
||||||
|
x = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
# run fwd and bwd
|
||||||
|
y = model(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# create temp directories for checkpoint
|
||||||
|
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||||
|
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
# save the model and optimizer
|
||||||
|
ckpt_io = GeneralCheckpointIO()
|
||||||
|
|
||||||
|
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||||
|
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||||
|
|
||||||
|
# create new model
|
||||||
|
new_model = resnet18()
|
||||||
|
new_optimizer = Adam(new_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
|
||||||
|
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
|
||||||
|
|
||||||
|
# check for model and optimizer state dict recursively
|
||||||
|
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||||
|
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||||
|
|
||||||
|
# continue running fwd and bwd
|
||||||
|
for _ in range(5):
|
||||||
|
y = new_model(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
new_optimizer.step()
|
||||||
|
|
||||||
|
# save the newly got optimizer
|
||||||
|
ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||||
|
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||||
|
|
||||||
|
# create another new model
|
||||||
|
new_new_model = resnet18()
|
||||||
|
new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
|
||||||
|
ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
|
||||||
|
|
||||||
|
# check for model and optimizer state dict recursively
|
||||||
|
check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
|
||||||
|
check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
def test_sharded_optimizer_multiple_param_groups():
|
||||||
|
|
||||||
|
# create a model and optimizer
|
||||||
|
model = resnet18()
|
||||||
|
optimizer = Adam([{'params': model.layer1.parameters()}, \
|
||||||
|
{'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
|
||||||
|
|
||||||
|
# create test data sample
|
||||||
|
x = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
# run fwd and bwd
|
||||||
|
y = model(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# create temp directories for checkpoint
|
||||||
|
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||||
|
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
# save the model and optimizer
|
||||||
|
ckpt_io = GeneralCheckpointIO()
|
||||||
|
|
||||||
|
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||||
|
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||||
|
|
||||||
|
# create new model
|
||||||
|
new_model = resnet18()
|
||||||
|
new_optimizer = Adam([{'params': new_model.layer1.parameters()}, \
|
||||||
|
{'params': new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
|
||||||
|
|
||||||
|
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
|
||||||
|
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
|
||||||
|
|
||||||
|
# check for model and optimizer state dict recursively
|
||||||
|
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||||
|
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||||
|
|
Loading…
Reference in New Issue