mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] General Checkpointing of Sharded Optimizers (#3984)
parent
8bcad73677
commit
c9cff7e7fa
|
@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
|
||||
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
model: GeminiDDP,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
Save sharded model
|
||||
"""
|
||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
|
||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
total_size = 0
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
for idx, shard_pair in enumerate(state_dict_shard):
|
||||
|
|
|
@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||
if self.coordinator.is_master():
|
||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
|
@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
|
||||
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
|
||||
|
||||
def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
|
||||
size_per_shard: int):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
|
||||
|
||||
|
||||
class TorchDDPModel(ModelWrapper):
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
from packaging import version
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
|
@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
|
||||
size_per_shard: int):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||
"""
|
||||
Load optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
|
|
|
@ -103,7 +103,7 @@ class CheckpointIO(ABC):
|
|||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor: bool = True,
|
||||
variant: str = None,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
|
@ -128,7 +128,7 @@ class CheckpointIO(ABC):
|
|||
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
|
||||
that the checkpoint path is a directory path instead of a file path.
|
||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
||||
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
|
||||
prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
|
||||
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
|
||||
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
|
||||
"""
|
||||
|
@ -137,11 +137,11 @@ class CheckpointIO(ABC):
|
|||
model = model.unwrap()
|
||||
|
||||
if shard:
|
||||
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
|
||||
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
|
||||
else:
|
||||
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
|
||||
"""
|
||||
Load optimizer from checkpoint.
|
||||
|
||||
|
@ -157,7 +157,7 @@ class CheckpointIO(ABC):
|
|||
|
||||
if index_file_exists:
|
||||
# the existence of index file means it is a sharded checkpoint
|
||||
self.load_sharded_optimizer(optimizer, index_file_path)
|
||||
self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
|
||||
else:
|
||||
self.load_unsharded_optimizer(optimizer, checkpoint)
|
||||
|
||||
|
@ -218,7 +218,7 @@ class CheckpointIO(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
"""
|
||||
Save model to sharded checkpoint.
|
||||
|
|
|
@ -11,15 +11,21 @@ from torch.optim import Optimizer
|
|||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
from .utils import (
|
||||
get_base_filenames,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
has_index_file,
|
||||
is_safetensors_available,
|
||||
load_param_groups_into_optimizer,
|
||||
load_shard_state_dict,
|
||||
load_state_dict,
|
||||
load_state_dict_into_model,
|
||||
load_states_into_optimizer,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
shard_checkpoint,
|
||||
shard_model_checkpoint,
|
||||
shard_optimizer_checkpoint,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
|
||||
__all__ = ['GeneralCheckpointIO']
|
||||
|
@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
# save the checkpoint
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file.
|
||||
"""
|
||||
optimizer.load_state_dict
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
|
||||
Lacking param group file under current directory.')
|
||||
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
|
||||
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
|
@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
prefix: str,
|
||||
size_per_shard: int,
|
||||
):
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
|
||||
"""
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Offload optimizer states. States are broken into shards within max_shard_size.
|
||||
state_dict = optimizer.state_dict()
|
||||
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
# Store the information of param groups to param_group_file.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(state_dict, group_file_path)
|
||||
|
||||
# Save shards of optimizer states.
|
||||
total_size = 0
|
||||
for idx, shard_pair in enumerate(sharded_state):
|
||||
shard, current_size = shard_pair
|
||||
shard_file = get_shard_filename(states_name, idx)
|
||||
total_size = total_size + current_size
|
||||
for param_id in shard.keys():
|
||||
index_file.append_weight_map(str(param_id), shard_file)
|
||||
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
||||
|
||||
# Wrap up index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self,
|
||||
|
@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
|
@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
# shard checkpoint
|
||||
state_dict = model.state_dict()
|
||||
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||
|
||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
total_size = 0
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
for idx, shard_pair in enumerate(state_dict_shard):
|
||||
|
@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
# read checkpoint index file
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
missing_keys = []
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
|
|
|
@ -111,7 +111,7 @@ class CheckpointIndexFile:
|
|||
return True
|
||||
return False
|
||||
|
||||
def get_checkpoint_fileanames(self) -> List[str]:
|
||||
def get_checkpoint_filenames(self) -> List[str]:
|
||||
"""
|
||||
Get the set of checkpoint filenames in the weight map.
|
||||
|
||||
|
@ -159,6 +159,18 @@ class CheckpointIndexFile:
|
|||
"""
|
||||
return list(self.weight_map.keys())
|
||||
|
||||
def get_param_group_filename(self) -> Union[str, None]:
|
||||
"""
|
||||
Get the file name of param_group file if this is a checkpoint for optimizer.
|
||||
Returns:
|
||||
str: param_group file name
|
||||
"""
|
||||
filename = self.metadata.get("param_groups", None)
|
||||
if filename:
|
||||
return str(self.root_path.joinpath(filename))
|
||||
else:
|
||||
return None
|
||||
|
||||
def write_index_file(self, save_index_file):
|
||||
"""
|
||||
Write index file.
|
||||
|
|
|
@ -1,17 +1,24 @@
|
|||
# coding=utf-8
|
||||
import re
|
||||
from collections import abc as container_abcs
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
STATES_NAME = "pytorch_optim.bin"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
|
||||
GROUP_FILE_NAME = "pytorch_optim_group.bin"
|
||||
|
||||
# ======================================
|
||||
# General helper functions
|
||||
|
@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
|||
# ======================================
|
||||
# Helper functions for saving shard file
|
||||
# ======================================
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
|
@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
|
|||
yield current_block, current_block_size
|
||||
|
||||
|
||||
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""
|
||||
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
"""
|
||||
|
||||
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
|
||||
states = state_dict['state']
|
||||
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
for param_id, state in states.items():
|
||||
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
|
||||
# A state might contain more than one tensors.
|
||||
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||
state_size = 0
|
||||
isDTensor = False
|
||||
for state_tensor in state.values():
|
||||
# If the states are stored as DTensors, mark isDTensor as true.
|
||||
if type(state_tensor) == DTensor:
|
||||
isDTensor = True
|
||||
state_size += calculate_tensor_size(state_tensor)
|
||||
|
||||
if not isDTensor:
|
||||
|
||||
if current_block_size + state_size > max_shard_size:
|
||||
ret_block = current_block
|
||||
ret_block_size = current_block_size
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
current_block[param_id] = state
|
||||
current_block_size += state_size
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
yield current_block, current_block_size
|
||||
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
"""
|
||||
load shard state dict into model
|
||||
|
@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module,
|
|||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
|
||||
"""
|
||||
Load information of param_groups into an initialized optimizer.
|
||||
"""
|
||||
|
||||
# Load list of param_groups from given file path.
|
||||
# The params in saved_groups are in the form of integer indices.
|
||||
saved_groups = torch.load(param_group_path)
|
||||
if not isinstance(saved_groups, List):
|
||||
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||
|
||||
# The params in param_groups are in the form of pytorch tensors.
|
||||
# For more details, please view source code of Optimizer class in pytorch.
|
||||
param_groups = optimizer.param_groups
|
||||
|
||||
# Check the compatibility of saved_groups and param_groups.
|
||||
if len(param_groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of original parameter groups")
|
||||
param_lens = (len(g['params']) for g in param_groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Creating mapping from id to parameters.
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in param_groups)))
|
||||
}
|
||||
|
||||
# Update parameter groups, setting their 'params' value.
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
|
||||
|
||||
optimizer.__dict__.update({'param_groups': updated_groups})
|
||||
return id_map
|
||||
|
||||
|
||||
def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
|
||||
r"""Copies states from `state_dict` into an Optimizer object.
|
||||
|
||||
Args:
|
||||
optimizer(Optimizer): An initialized Optimizer object to be loaded
|
||||
state_dict(dict): a mapping from tensor index (an integer)
|
||||
to its states to be loaded (a mapping from state name to a tensor).
|
||||
id_map(dict): a mapping from tensor index (an integer)
|
||||
to its corresponding parameter (a tensor) whose states will be updated.
|
||||
"""
|
||||
|
||||
def cast(param, value, key=None):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
|
||||
if (key != "step"):
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
new_states = defaultdict(dict)
|
||||
for k, v in state_dict.items():
|
||||
if k in id_map:
|
||||
param = id_map[k]
|
||||
new_states[param] = cast(param, v)
|
||||
else:
|
||||
new_states[k] = v
|
||||
|
||||
optimzier.state.update(new_states)
|
||||
|
||||
|
||||
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
|
||||
# Do the cleaning up as in src code of Pytorch.
|
||||
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
optimizer.defaults.setdefault('differentiable', False)
|
||||
|
||||
|
||||
# ======================================
|
||||
# Helper functions for saving state dict
|
||||
# ======================================
|
||||
|
@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
|
|||
torch.save(state_dict, checkpoint_file_path)
|
||||
|
||||
|
||||
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
|
||||
"""
|
||||
Save information of param_groups to given file path.
|
||||
|
||||
Args:
|
||||
state_dict (dict): state dict.
|
||||
group_file_path (str): path to the group file.
|
||||
"""
|
||||
param_groups = state_dict["param_groups"]
|
||||
torch.save(param_groups, group_file_path)
|
||||
|
||||
|
||||
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
|
||||
"""
|
||||
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
|
||||
|
@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path):
|
|||
return torch.load(checkpoint_file_path)
|
||||
|
||||
|
||||
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None and len(variant) > 0:
|
||||
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
|
||||
if prefix is not None and len(prefix) > 0:
|
||||
splits = weights_name.split(".")
|
||||
splits = splits[:-1] + [variant] + splits[-1:]
|
||||
splits = splits[:-1] + [prefix] + splits[-1:]
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
|
||||
|
||||
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
|
||||
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
|
||||
"""
|
||||
generate base weight filenames
|
||||
generate base model weight filenames
|
||||
"""
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
weights_name = add_prefix(weights_name, prefix)
|
||||
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = add_variant(save_index_file, variant)
|
||||
save_index_file = add_prefix(save_index_file, prefix)
|
||||
|
||||
return weights_name, save_index_file
|
||||
|
||||
|
||||
def get_optimizer_base_filenames(prefix: str = None):
|
||||
"""
|
||||
generate base optimizer state filenames
|
||||
"""
|
||||
states_name = STATES_NAME
|
||||
states_name = add_prefix(states_name, prefix)
|
||||
|
||||
save_index_file = STATES_INDEX_NAME
|
||||
save_index_file = add_prefix(save_index_file, prefix)
|
||||
|
||||
param_group_file = GROUP_FILE_NAME
|
||||
param_group_file = add_prefix(param_group_file, prefix)
|
||||
|
||||
return states_name, save_index_file, param_group_file
|
||||
|
||||
|
||||
def get_shard_filename(weights_name: str, idx: int):
|
||||
"""
|
||||
get shard file name
|
||||
|
|
|
@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
|
|||
|
||||
|
||||
@pytest.mark.parametrize('use_safetensors', [True, False])
|
||||
def test_sharded_checkpoint(use_safetensors: bool):
|
||||
def test_sharded_model_checkpoint(use_safetensors: bool):
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
optimizer = Adam(model.parameters(), lr=0.001)
|
||||
|
@ -100,3 +100,101 @@ def test_sharded_checkpoint(use_safetensors: bool):
|
|||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
||||
|
||||
def test_sharded_optimizer_checkpoint():
|
||||
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
optimizer = Adam(model.parameters(), lr=0.001)
|
||||
|
||||
# create test data sample
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# run fwd and bwd
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# create temp directories for checkpoint
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
# save the model and optimizer
|
||||
ckpt_io = GeneralCheckpointIO()
|
||||
|
||||
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||
|
||||
# create new model
|
||||
new_model = resnet18()
|
||||
new_optimizer = Adam(new_model.parameters(), lr=0.001)
|
||||
|
||||
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
|
||||
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
|
||||
|
||||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
||||
# continue running fwd and bwd
|
||||
for _ in range(5):
|
||||
y = new_model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
new_optimizer.step()
|
||||
|
||||
# save the newly got optimizer
|
||||
ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||
|
||||
# create another new model
|
||||
new_new_model = resnet18()
|
||||
new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
|
||||
|
||||
ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
|
||||
ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
|
||||
|
||||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
|
||||
check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
|
||||
|
||||
|
||||
def test_sharded_optimizer_multiple_param_groups():
|
||||
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
optimizer = Adam([{'params': model.layer1.parameters()}, \
|
||||
{'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
|
||||
|
||||
# create test data sample
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# run fwd and bwd
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# create temp directories for checkpoint
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
# save the model and optimizer
|
||||
ckpt_io = GeneralCheckpointIO()
|
||||
|
||||
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||
|
||||
# create new model
|
||||
new_model = resnet18()
|
||||
new_optimizer = Adam([{'params': new_model.layer1.parameters()}, \
|
||||
{'params': new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001)
|
||||
|
||||
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
|
||||
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
|
||||
|
||||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
|
Loading…
Reference in New Issue