Browse Source

[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)

pull/4026/head
Baizhou Zhang 1 year ago committed by GitHub
parent
commit
822c3d4d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 49
      colossalai/booster/booster.py
  2. 10
      colossalai/booster/plugin/torch_ddp_plugin.py
  3. 8
      colossalai/checkpoint_io/checkpoint_io_base.py
  4. 10
      colossalai/checkpoint_io/general_checkpoint_io.py
  5. 16
      colossalai/checkpoint_io/utils.py
  6. 20
      tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py

49
colossalai/booster/booster.py

@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@ -165,11 +166,11 @@ class Booster:
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.
Args:
model (nn.Module): A model boosted by Booster.
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
strict (bool, optional): whether to strictly enforce that the keys
@ -179,24 +180,34 @@ class Booster:
self.checkpoint_io.load_model(model, checkpoint, strict)
def save_model(self,
model: nn.Module,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""Save model to checkpoint.
Args:
model (nn.Module): A model boosted by Booster.
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
"""
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
self.checkpoint_io.save_model(model,
checkpoint=checkpoint,
shard=shard,
gather_dtensor=gather_dtensor,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.
@ -205,12 +216,21 @@ class Booster:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
"""Save optimizer to checkpoint.
Warning: Saving sharded optimizer checkpoint is not supported yet.
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
@ -218,9 +238,12 @@ class Booster:
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""Save lr scheduler to checkpoint.

10
colossalai/booster/plugin/torch_ddp_plugin.py

@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint but only on master process.
"""

8
colossalai/checkpoint_io/checkpoint_io_base.py

@ -148,6 +148,9 @@ class CheckpointIO(ABC):
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
@ -157,7 +160,7 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
@ -251,7 +254,7 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load optimizer from sharded checkpoint.
@ -259,7 +262,6 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass

10
colossalai/checkpoint_io/general_checkpoint_io.py

@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO):
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load sharded optimizer with the given path to index file.
"""
optimizer.load_state_dict
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.optim
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)

16
colossalai/checkpoint_io/utils.py

@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
# The calculation of tensor size should be skipped to avoid error.
if state_tensor is None:
continue
# If the states are stored as DTensors, mark isDTensor as true.
if type(state_tensor) == DTensor:
isDTensor = True
@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
return id_map
def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: d
else:
new_states[k] = v
optimzier.state.update(new_states)
optimizer.state.update(new_states)
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
r"""Do the cleaning up work after state_dict has been loaded into optimizer
Args:
optimizer(Optimizer): An optimizer object whose state has just been loaded.
"""
# Do the cleaning up as in src code of Pytorch.
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault('differentiable', False)

20
tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py

@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
@parameterize('shard', [True, False])
def check_torch_ddp_checkpointIO(shard: bool):
@parameterize('size_per_shard', [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@ -38,11 +39,9 @@ def check_torch_ddp_checkpointIO(shard: bool):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()
new_model = resnet18()
@ -55,11 +54,10 @@ def check_torch_ddp_checkpointIO(shard: bool):
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
def run_dist(rank, world_size, port):

Loading…
Cancel
Save