From 63ecafb1fba0ac1fa673c0394ffb701fec95f99c Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 4 Sep 2023 11:26:45 +0800
Subject: [PATCH] [checkpointio] optimize zero optim checkpoint io (#4591)

* [zero] update checkpoint io to save memory

* [checkpointio] add device map to save memory
---
 .../booster/plugin/low_level_zero_plugin.py       | 51 ++++++++++++++-----
 colossalai/checkpoint_io/general_checkpoint_io.py |  2 -
 colossalai/checkpoint_io/utils.py                 |  6 +--
 colossalai/zero/low_level/low_level_optim.py      |  6 +--
 4 files changed, 43 insertions(+), 22 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 616b218b2..6efafc56d 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -17,8 +17,13 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
+    load_param_groups_into_optimizer,
+    load_shard_state_dict,
+    load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -126,19 +131,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
-        current_rank_state_dict = optimizer.optim.state_dict()['state']
-        for param_idx, state in current_rank_state_dict.items():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
-                    padding_size = (self.coordinator.world_size -
-                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
-                    with torch.no_grad():
-                        v = v.flatten()
-                        if padding_size > 0:
-                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // self.coordinator.world_size)
-                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            # shard state dict
+            for param_idx, state in state_dict.items():
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != 'step':
+                        padding_size = (self.coordinator.world_size -
+                                        v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                        with torch.no_grad():
+                            v = v.flatten()
+                            if padding_size > 0:
+                                v = torch.nn.functional.pad(v, [0, padding_size])
+                            v_list = v.split(v.numel() // self.coordinator.world_size)
+                            state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
 
 class LowLevelZeroModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 83e4bdcc8..34210ea52 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -78,8 +78,6 @@ class GeneralCheckpointIO(CheckpointIO):
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             load_states_into_optimizer(optimizer, state_dict, id_map)
-            del state_dict
-            gc.collect()
 
         sharded_optimizer_loading_epilogue(optimizer)
 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 8837776ae..77ff7784a 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -237,7 +237,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
                 f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
-        return torch.load(checkpoint_file)
+        return torch.load(checkpoint_file, map_location=torch.device('cpu'))
 
 
 def load_state_dict_into_model(model: nn.Module,
@@ -297,7 +297,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
 
     # Load list of param_groups from given file path.
     # The params in saved_groups are in the form of integer indices.
-    saved_groups = torch.load(param_group_path)
+    saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
 
     if not isinstance(saved_groups, List):
         raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -608,7 +608,7 @@
 
     else:
         # load with torch
-        return torch.load(checkpoint_file_path)
+        return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
 
 
 def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 96d5902e8..b4439ab19 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -307,7 +307,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
         if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
-            group_id != self._bucket_store.current_group_id:
+                group_id != self._bucket_store.current_group_id:
             self._run_reduction()
 
         padding_size = self._param_store.get_param_padding_size(param)
@@ -553,11 +553,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
                         v_list = v.split(v.numel() // self._world_size)
-                        device = 'cpu' if self._cpu_offload else 'cuda'
-                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
 
         self.optim.load_state_dict(zero_state_dict)
-        zero_state_dict = dict()
 
     def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
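
For reference, the per-rank sharding step that both rewritten loaders apply to every optimizer state tensor can be illustrated in isolation. The sketch below is not part of the patch; the helper name shard_for_rank and the toy sizes are invented for illustration. Each tensor loaded from a shard file (with map_location=torch.device('cpu'), so the checkpoint is never materialized on the GPU) is flattened, padded to a multiple of the world size, split, and only the local rank's slice is kept. The final .clone() matters because split() returns views: cloning detaches the slice, so the full flat tensor can be garbage collected instead of being kept alive by its shard.

import torch


def shard_for_rank(v: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Illustrative helper (not from the patch): keep only this rank's slice of a state tensor."""
    flat = v.flatten()
    # Pad so the element count divides evenly across ranks.
    padding_size = (world_size - flat.numel() % world_size) % world_size
    if padding_size > 0:
        flat = torch.nn.functional.pad(flat, [0, padding_size])
    # split() returns views of `flat`; clone() detaches the slice so `flat` can be freed.
    v_list = flat.split(flat.numel() // world_size)
    return v_list[rank].detach().clone()


if __name__ == '__main__':
    # Toy example: a 10-element state tensor sharded across 4 ranks (padded to 12 elements).
    exp_avg = torch.arange(10, dtype=torch.float32)
    print(shard_for_rank(exp_avg, rank=1, world_size=4))    # tensor([3., 4., 5.])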