diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 2ea7593a5..5445b4a63 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,3 +1,5 @@ +import logging +import os import warnings from pathlib import Path from typing import Callable, Iterable, Iterator, List, Optional, Tuple @@ -25,7 +27,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -74,17 +76,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): def save_sharded_model( self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool, - prefix: Optional[str], - size_per_shard: int, - use_safetensors: bool, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, ): """ Save model to checkpoint but only on master process. """ - raise NotImplementedError("Sharded model checkpoint is not supported yet.") + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + with FSDP.state_dict_type( + model.unwrap(), + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + state_dict = model.unwrap().state_dict() + + state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard) + + weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint_path) + + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + utils.save_config_file(model.unwrap(), checkpoint_path) + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_model( self, @@ -97,7 +136,24 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ Load model to checkpoint but only on master process. """ - raise NotImplementedError("Sharded model checkpoint is not supported yet.") + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not utils.is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # read checkpoint index file + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + fsdp_state_dict = {} + for shard_file in checkpoint_files: + fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors)) + + with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT): + model.unwrap().load_state_dict(fsdp_state_dict, strict=False) def save_sharded_optimizer( self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int @@ -105,13 +161,86 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ Save optimizer to checkpoint but only on master process. """ - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + with FSDP.state_dict_type( + optimizer.unwrap_model().unwrap(), + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + fsdp_optim_state = FSDP.full_optim_state_dict( + optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True + ) + + if self.coordinator.is_master(): + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + utils.save_param_groups(fsdp_optim_state, group_file_path) + + sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard) + + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): """ Load optimizer to checkpoint but only on master process. """ - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" + + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. " + "Looking param group file under current directory." + ) + + saved_param_groups = torch.load(param_group_path) + + # Load param + fsdp_optim_state = {} + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + for shard_file in checkpoint_files: + state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) + fsdp_optim_state.update(state_dict_shard) + + fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) + + with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT): + fsdp_state = FSDP.optim_state_dict_to_load( + model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict + ) + optimizer.load_state_dict(fsdp_state) + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ @@ -190,7 +319,7 @@ class TorchFSDPPlugin(DPPluginBase): raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") def support_no_sync(self) -> bool: - False + return False def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError("Torch fsdp no_sync func not supported yet.") diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index dd41f8185..dca562a3b 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -10,6 +10,7 @@ from colossalai.booster import Booster if version.parse(torch.__version__) >= version.parse("1.12.0"): from colossalai.booster.plugin import TorchFSDPPlugin + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -99,6 +100,43 @@ def check_torch_fsdp_ckpt(): outputs_sec = fsdp_model(inputs) assert criterion(outputs_sec) == criterion(outputs) + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optim_ckpt_path = f"{tempdir}/optimizer" + + run_model() + + booster.save_model(fsdp_model, model_ckpt_path, shard=True) + booster.save_optimizer(optimizer, optim_ckpt_path, shard=True) + + full_msd = fsdp_model.unwrap().state_dict() + full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) + + import copy + sharded_osd = copy.deepcopy(full_osd) + + run_model() + + full_msd_updated = fsdp_model.unwrap().state_dict() + full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) + + # cost much time led to timeout + # assert not compare_nested_dict(full_osd_updated, sharded_osd) + # assert not compare_nested_dict(full_msd_updated, full_msd) + outputs_first = fsdp_model(inputs) + assert criterion(outputs_first) != criterion(outputs) + + booster.load_model(fsdp_model, model_ckpt_path) + booster.load_optimizer(optimizer, optim_ckpt_path) + + full_msd_restore = fsdp_model.unwrap().state_dict() + sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) + + assert compare_nested_dict(sharded_osd, sharded_osd_restore) + assert compare_nested_dict(full_msd_restore, full_msd) + outputs_sec = fsdp_model(inputs) + assert criterion(outputs_sec) == criterion(outputs) + def run_dist(rank, world_size, port): # init dist env