From 822c3d4d66d2d74cb7c7080abed6a207602dddfd Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 16 Jun 2023 14:14:05 +0800
Subject: [PATCH] [checkpointio] sharded optimizer checkpoint for DDP plugin
 (#4002)

---
 colossalai/booster/booster.py                 | 49 ++++++++++++++-----
 colossalai/booster/plugin/torch_ddp_plugin.py | 10 ++--
 .../checkpoint_io/checkpoint_io_base.py       |  8 +--
 .../checkpoint_io/general_checkpoint_io.py    | 10 +++-
 colossalai/checkpoint_io/utils.py             | 16 +++++-
 .../test_torch_ddp_checkpoint_io.py           | 20 ++++----
 6 files changed, 79 insertions(+), 34 deletions(-)

diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 6e480d0db..cee547b33 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -165,11 +166,11 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
             strict (bool, optional): whether to strictly enforce that the keys
@@ -179,24 +180,34 @@
         self.checkpoint_io.load_model(model, checkpoint, strict)
 
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
-                   size_per_shard: int = 1024):
+                   gather_dtensor: bool = True,
+                   prefix: Optional[str] = None,
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """Save model to checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
""" - self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard) + self.checkpoint_io.save_model(model, + checkpoint=checkpoint, + shard=shard, + gather_dtensor=gather_dtensor, + prefix=prefix, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors) def load_optimizer(self, optimizer: Optimizer, checkpoint: str): """Load optimizer from checkpoint. @@ -205,12 +216,21 @@ class Booster: optimizer (Optimizer): An optimizer boosted by Booster. checkpoint (str): Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ self.checkpoint_io.load_optimizer(optimizer, checkpoint) - def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): - """Save optimizer to checkpoint. - Warning: Saving sharded optimizer checkpoint is not supported yet. + def save_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + """ + Save optimizer to checkpoint. Args: optimizer (Optimizer): An optimizer boosted by Booster. @@ -218,9 +238,12 @@ class Booster: It is a file path if ``shard=False``. Otherwise, it is a directory path. shard (bool, optional): Whether to save checkpoint a sharded way. If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ - self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard) + self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """Save lr scheduler to checkpoint. diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 4bfd61af3..71b435155 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): def save_sharded_model(self, model: nn.Module, checkpoint_path: str, - gather_dtensor: bool = False, + gather_dtensor: bool = True, prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): @@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) - def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): """ Save optimizer to checkpoint but only on master process. 
""" diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 9d513043f..8ff9d87c2 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -148,6 +148,9 @@ class CheckpointIO(ABC): Args: optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ index_file_exists, index_file_path = has_index_file(checkpoint) @@ -157,7 +160,7 @@ class CheckpointIO(ABC): if index_file_exists: # the existence of index file means it is a sharded checkpoint - self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard) + self.load_sharded_optimizer(optimizer, index_file_path, prefix) else: self.load_unsharded_optimizer(optimizer, checkpoint) @@ -251,7 +254,7 @@ class CheckpointIO(ABC): # ======================================================== @abstractmethod - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): """ Load optimizer from sharded checkpoint. @@ -259,7 +262,6 @@ class CheckpointIO(ABC): optimizer (Optimizer): optimizer to be loaded. index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. - size_per_shard (int): size per shard in MB. """ pass diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index d8e133313..26cafcada 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple import torch.nn as nn from torch.optim import Optimizer +from colossalai.interface import OptimizerWrapper + from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO): # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): """ Load sharded optimizer with the given path to index file. """ - optimizer.load_state_dict + + # If optimizer is wrapped, unwrap it. + if isinstance(optimizer, OptimizerWrapper): + optimizer = optimizer.optim + # Read checkpoint index file. 
        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 21b70343b..3dada00cd 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
         state_size = 0
         isDTensor = False
         for state_tensor in state.values():
+
+            # When state_tensor is None (e.g., an SGD optimizer with momentum set to 0),
+            # the calculation of tensor size should be skipped to avoid errors.
+            if state_tensor is None:
+                continue
+
             # If the states are stored as DTensors, mark isDTensor as true.
             if type(state_tensor) == DTensor:
                 isDTensor = True
@@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     return id_map
 
 
-def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
     r"""Copies states from `state_dict` into an Optimizer object.
 
     Args:
@@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: d
             else:
                 new_states[k] = v
 
-    optimzier.state.update(new_states)
+    optimizer.state.update(new_states)
 
 
 def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleanup work after state_dict has been loaded into the optimizer.
+
+    Args:
+        optimizer (Optimizer): An optimizer object whose state has just been loaded.
+    """
+
+    # Do the cleanup as in the source code of PyTorch.
     optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault('differentiable', False)
diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
index 5501ee4e3..14332b5b3 100644
--- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
 
 
 @parameterize('shard', [True, False])
-def check_torch_ddp_checkpointIO(shard: bool):
+@parameterize('size_per_shard', [16, 128])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -38,11 +39,9 @@
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
-            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()
 
         new_model = resnet18()
@@ -55,11 +54,10 @@
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
 
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
-            booster.load_lr_scheduler(new_scheduler,
lr_scheduler_ckpt_path) - check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) + check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) def run_dist(rank, world_size, port):
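
---

Usage sketch (illustration only, not part of the patch): a minimal example of
the workflow this change enables, saving and reloading a sharded optimizer
checkpoint through the DDP plugin. The ResNet-18/SGD setup mirrors the test
above; the checkpoint path "ckpt/optimizer" and the 128 MB shard size are
hypothetical choices, and the script assumes it is launched under torchrun so
that colossalai.launch_from_torch can read the distributed environment
variables.

    import torch
    import torch.distributed as dist
    from torchvision.models import resnet18

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    # Initialize the distributed environment from torchrun-provided env vars.
    colossalai.launch_from_torch(config={})

    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = resnet18()
    # Nonzero momentum gives the optimizer real state tensors to checkpoint.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # With shard=True the checkpoint is a directory of shard files plus a
    # .index.json file; per this patch, only the master process writes it.
    booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=128)
    dist.barrier()

    # Loading uses the same entry point for sharded and unsharded checkpoints:
    # the index file is detected automatically (see has_index_file above).
    booster.load_optimizer(optimizer, "ckpt/optimizer")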