mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)
parent 725af3eeeb
commit 822c3d4d66
@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -165,11 +166,11 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
             strict (bool, optional): whether to strictly enforce that the keys
@@ -179,24 +180,34 @@ class Booster:
         self.checkpoint_io.load_model(model, checkpoint, strict)
 
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
-                   size_per_shard: int = 1024):
+                   gather_dtensor: bool = True,
+                   prefix: Optional[str] = None,
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """Save model to checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
         """
-        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
+        self.checkpoint_io.save_model(model,
+                                      checkpoint=checkpoint,
+                                      shard=shard,
+                                      gather_dtensor=gather_dtensor,
+                                      prefix=prefix,
+                                      size_per_shard=size_per_shard,
+                                      use_safetensors=use_safetensors)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
@@ -205,12 +216,21 @@ class Booster:
             optimizer (Optimizer): An optimizer boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         self.checkpoint_io.load_optimizer(optimizer, checkpoint)
 
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        """Save optimizer to checkpoint.
-        Warning: Saving sharded optimizer checkpoint is not supported yet.
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       gather_dtensor: bool = True,
+                       prefix: Optional[str] = None,
+                       size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): An optimizer boosted by Booster.
@@ -218,9 +238,12 @@ class Booster:
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """Save lr scheduler to checkpoint.
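
For orientation, the sketch below exercises the extended Booster checkpoint API from this commit. It is a minimal, hypothetical usage example rather than repository code: it assumes the distributed environment has already been initialized (for example via colossalai.launch_from_torch), and the paths and shard size are illustrative.

```python
# Minimal usage sketch (assumes colossalai.launch_from_torch(config={}) or a
# similar launcher has already set up the process group; paths are illustrative).
import torch
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded model checkpoint: "./ckpt/model" becomes a folder of <=128 MB shard files.
booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=128)

# Sharded optimizer checkpoint -- the capability this commit adds for the DDP plugin.
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True, size_per_shard=128)

# Loading dispatches on the presence of an index file, so the calls look the same
# for sharded and unsharded checkpoints.
booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer")
```
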
@@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(self,
                            model: nn.Module,
                            checkpoint_path: str,
-                           gather_dtensor: bool = False,
+                           gather_dtensor: bool = True,
                            prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
@@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
 
-    def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
-                              size_per_shard: int):
+    def save_sharded_optimizer(self,
+                               optimizer: Optimizer,
+                               checkpoint: str,
+                               gather_dtensor: bool = True,
+                               prefix: Optional[str] = None,
+                               size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint but only on master process.
         """
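
The `is_master()` guard shown in `save_sharded_model` is the pattern the DDP checkpoint IO applies to all writes: every rank holds an identical replica, so only the master process needs to touch the filesystem. The snippet below is a generic illustration of that pattern, my own sketch rather than the method body from this commit.

```python
# Hedged sketch of the "write only on the master rank" pattern, not repository code.
import torch
import torch.distributed as dist
from torch.optim import Optimizer


def save_optimizer_master_only(optimizer: Optimizer, checkpoint: str):
    # Under DDP every rank holds identical optimizer states, so one writer suffices.
    is_master = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_master:
        torch.save(optimizer.state_dict(), checkpoint)
    # A barrier keeps other ranks from racing ahead and reading a half-written file.
    if dist.is_initialized():
        dist.barrier()
```
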
@@ -148,6 +148,9 @@ class CheckpointIO(ABC):
         Args:
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         index_file_exists, index_file_path = has_index_file(checkpoint)
 
@@ -157,7 +160,7 @@ class CheckpointIO(ABC):
 
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
 
@@ -251,7 +254,7 @@ class CheckpointIO(ABC):
     # ========================================================
 
     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load optimizer from sharded checkpoint.
 
@@ -259,7 +262,6 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
-            size_per_shard (int): size per shard in MB.
         """
         pass
 
@@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load sharded optimizer with the given path to index file.
         """
+        optimizer.load_state_dict
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.optim
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
 
@@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         state_size = 0
         isDTensor = False
         for state_tensor in state.values():
+
+            # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+            # The calculation of tensor size should be skipped to avoid error.
+            if state_tensor is None:
+                continue
+
             # If the states are stored as DTensors, mark isDTensor as true.
             if type(state_tensor) == DTensor:
                 isDTensor = True
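
The `None` check added above matters because some optimizers keep placeholder entries in their per-parameter state (as the comment notes, SGD leaves `momentum_buffer` as `None` when momentum is 0), and sizing such an entry would raise. Below is a self-contained illustration of the same skip logic; the helper name and the MB arithmetic are mine, not the repository's.

```python
# Hedged illustration: approximate size (in MB) of one parameter's optimizer
# state, skipping non-tensor placeholder entries exactly as the hunk above does.
import torch


def state_size_mb(state: dict) -> float:
    size_bytes = 0
    for state_tensor in state.values():
        # SGD with momentum=0 can store momentum_buffer=None; skip it.
        if state_tensor is None:
            continue
        if torch.is_tensor(state_tensor):
            size_bytes += state_tensor.numel() * state_tensor.element_size()
    return size_bytes / 1024 ** 2


p = torch.nn.Parameter(torch.randn(1024, 1024))
opt = torch.optim.SGD([p], lr=0.1)    # momentum defaults to 0
p.grad = torch.zeros_like(p)
opt.step()
print(state_size_mb(opt.state[p]))    # ~0.0: no real state tensors to count
```
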
@@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     return id_map
 
 
-def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
     r"""Copies states from `state_dict` into an Optimizer object.
 
     Args:
@@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: d
         else:
             new_states[k] = v
 
-    optimzier.state.update(new_states)
+    optimizer.state.update(new_states)
 
 
 def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleaning up work after state_dict has been loaded into optimizer
+
+    Args:
+        optimizer(Optimizer): An optimizer object whose state has just been loaded.
+    """
+
     # Do the cleaning up as in src code of Pytorch.
     optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault('differentiable', False)
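
Together with `load_param_groups_into_optimizer` above, these helpers form the sharded loading pipeline: restore the param groups and id map, stream each shard's states into the optimizer, then run the epilogue. The sketch below is one hypothetical way to chain them, assuming the helpers are importable from colossalai.checkpoint_io.utils; the shard filename pattern and param-group filename are my assumptions, each shard is assumed to hold a flat state fragment, and the real GeneralCheckpointIO.load_sharded_optimizer resolves shard files through CheckpointIndexFile rather than globbing a directory.

```python
# Hedged sketch of the sharded-optimizer loading flow; file discovery is simplified.
import glob
import os

import torch
from torch.optim import Optimizer

from colossalai.checkpoint_io.utils import (
    load_param_groups_into_optimizer,
    load_states_into_optimizer,
    sharded_optimizer_loading_epilogue,
)


def naive_load_sharded_optimizer(optimizer: Optimizer, checkpoint_dir: str, param_group_file: str):
    # 1. Restore param groups and build the mapping from saved ids to live params.
    id_map = load_param_groups_into_optimizer(optimizer, os.path.join(checkpoint_dir, param_group_file))

    # 2. Stream each shard's state fragment into the optimizer.
    for shard_file in sorted(glob.glob(os.path.join(checkpoint_dir, "*.bin"))):
        state_dict = torch.load(shard_file, map_location="cpu")
        load_states_into_optimizer(optimizer, state_dict, id_map)

    # 3. Clean up, mirroring torch.optim.Optimizer.load_state_dict.
    sharded_optimizer_loading_epilogue(optimizer)
```
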
@@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
 
 
 @parameterize('shard', [True, False])
-def check_torch_ddp_checkpointIO(shard: bool):
+@parameterize('size_per_shard', [16, 128])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -38,11 +39,9 @@ def check_torch_ddp_checkpointIO(shard: bool):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
-            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()
 
         new_model = resnet18()
@@ -55,11 +54,10 @@ def check_torch_ddp_checkpointIO(shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
 
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
-            booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
-            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
+        check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
 
 
 def run_dist(rank, world_size, port):
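
The parameterized check above runs inside a spawned process group. The sketch below shows the generic launcher pattern such distributed tests follow; it is hypothetical wiring rather than the remainder of this file, and the spawn helper, empty config, and two-process world size are assumptions.

```python
# Hedged sketch of the launcher pattern; the file's actual run_dist body is not shown above.
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn  # spawn is assumed here


def run_dist(rank, world_size, port):
    # Bring up the process group before exercising the Booster checkpoint API.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port)
    check_torch_ddp_checkpointIO()


@rerun_if_address_is_in_use()
def test_torch_ddp_checkpointIO():
    spawn(run_dist, 2)
```
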