From 6b305a99d6bf33cb2f6fdeb4b5e099cd026cf485 Mon Sep 17 00:00:00 2001
From: wukong1992
Date: Tue, 23 May 2023 16:58:45 +0800
Subject: [PATCH] [booster] torch fsdp fix ckpt (#3788)

---
 colossalai/booster/booster.py                 |   2 +-
 .../booster/plugin/torch_fsdp_plugin.py       | 219 ++++++------------
 .../checkpoint_io/general_checkpoint_io.py    |  77 +++---
 .../test_plugin/test_torch_fsdp_plugin.py     |   5 +-
 .../test_torch_fsdp_checkpoint_io.py          | 113 +++++++++
 5 files changed, 230 insertions(+), 186 deletions(-)
 create mode 100644 tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py

diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index be9c1c9dc..61d912157 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -196,7 +196,7 @@ class Booster:
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index 0daefa9ff..340555dc6 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -5,30 +6,18 @@ import torch.nn as nn
 from packaging import version
 from torch.distributed import ProcessGroup
 
-if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-        torch.__version__) < version.parse('2.0.0'):
+
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
     from torch.distributed.fsdp import FullStateDictConfig
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp import StateDictType
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
-        MixedPrecision,
-        ShardingStrategy,
-    )
-elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp._init_utils import ProcessGroupType
-    from torch.distributed.fsdp.api import (
-        BackwardPrefetch,
-        CPUOffload,
-        FullOptimStateDictConfig,
         FullStateDictConfig,
         MixedPrecision,
         ShardingStrategy,
-        StateDictType,
     )
-    from torch.distributed.fsdp.wrap import _FSDPPolicy
 else:
     raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
 
@@ -36,7 +25,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def __set_model_optim_state(
-        self,
-        model,
-        state_dict_type,
-        state_dict_config,
-        optim_state_dict_config,
-    ):
-        return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        checkpoint = utils.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint)
 
-    def load_sharded_model(self, model: nn.Module, checkpoint: str):
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = utils.load_state_dict(checkpoint)
+        fsdp_model = optimizer.unwrap_model()
+        sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
+        optimizer.load_state_dict(sharded_osd)
 
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def save_sharded_model(self, model: nn.Module, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str):
-        """
-        Load model from checkpoint with automatic unwrapping.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            full_state_dict = self.load_state_dict(checkpoint)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            full_state_dict = self.load_state_dict(checkpoint)
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
-            full_state_dict = model.state_dict()
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-
-        model.load_state_dict(full_state_dict)
-
-    def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
-        """
-        Load Optimizer from checkpoint with automatic unwrapping.
-        """
-
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            optim_full_state_dict = self.load_state_dict(checkpoint)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            optim_full_state_dict = self.load_state_dict(checkpoint)
-            FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-
-        optim.load_state_dict(optim_full_state_dict)
-
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
+            full_model_state = model.state_dict()
+        utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
-                model_state_dict = model.state_dict()
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
-            model_state_dict = model.state_dict()
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        self.save_checkpoint(model_state_dict, checkpoint)
-
-    def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
+        assert isinstance(optimizer, FSDPOptimizerWrapper)
+        fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
+        utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
-                                         FullOptimStateDictConfig(rank0_only=True))
-            optim_state_dict = FSDP.optim_state_dict(model, optimizer)
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        self.save_checkpoint(optim_state_dict, checkpoint)
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+                           size_per_shard: int, use_safetensors: bool):
+        """
+        Save model to sharded checkpoint (not supported by Torch FSDP yet).
+        """
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        """
+        Load model from sharded checkpoint (not supported by Torch FSDP yet).
+        """
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+        """
+        Save optimizer to sharded checkpoint (not supported by Torch FSDP yet).
+        """
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+        """
+        Load optimizer from sharded checkpoint (not supported by Torch FSDP yet).
+        """
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Save lr scheduler to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_lr_scheduler(lr_scheduler, checkpoint)
 
 
 class TorchFSDPModel(ModelWrapper):
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
         self.module = FSDP(module, *args, **kwargs)
 
     def unwrap(self):
-        return self.module.module
+        return self.module
+
+
+class FSDPOptimizerWrapper(OptimizerWrapper):
+
+    def __init__(self, optimizer: Optimizer, model: nn.Module):
+        self.model = model
+        super().__init__(optimizer)
+
+    def unwrap_model(self) -> nn.Module:
+        return self.model
 
 
 class TorchFSDPPlugin(DPPluginBase):
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
         See https://pytorch.org/docs/stable/fsdp.html for details.
     """
 
-    if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-            torch.__version__) < version.parse('2.0.0'):
+    if version.parse(torch.__version__) >= version.parse('1.12.0'):
 
         def __init__(
             self,
@@ -191,7 +158,6 @@
             mixed_precision: Optional[MixedPrecision] = None,
             ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
             param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-            device_id: Optional[Union[int, torch.device]] = None,
             sync_module_states: bool = False,
         ):
             super().__init__()
@@ -203,42 +169,7 @@
                                     mixed_precision=mixed_precision,
                                     ignored_modules=ignored_modules,
                                     param_init_fn=param_init_fn,
-                                    device_id=device_id,
                                     sync_module_states=sync_module_states)
 
-    elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-
-        def __init__(
-            self,
-            process_group: ProcessGroupType = None,
-            sharding_strategy: Optional[ShardingStrategy] = None,
-            cpu_offload: Optional[CPUOffload] = None,
-            auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
-            backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
-            mixed_precision: Optional[MixedPrecision] = None,
-            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
-            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-            device_id: Optional[Union[int, torch.device]] = None,
-            sync_module_states: bool = False,
-            forward_prefetch: bool = False,
-            limit_all_gathers: bool = False,
-            use_orig_params: bool = False,
-            ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
-        ):
-            super().__init__()
-            self.fsdp_kwargs = dict(process_group=process_group,
-                                    sharding_strategy=sharding_strategy,
-                                    cpu_offload=cpu_offload,
-                                    auto_wrap_policy=auto_wrap_policy,
-                                    backward_prefetch=backward_prefetch,
-                                    mixed_precision=mixed_precision,
-                                    ignored_modules=ignored_modules,
-                                    param_init_fn=param_init_fn,
-                                    device_id=device_id,
-                                    sync_module_states=sync_module_states,
-                                    forward_prefetch=forward_prefetch,
-                                    limit_all_gathers=limit_all_gathers,
-                                    use_orig_params=use_orig_params,
-                                    ignored_parameters=ignored_parameters)
     else:
         raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
 
@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
         lr_scheduler: LRScheduler = None,
     ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
 
-        model = model.cuda()
         # wrap the model with PyTorch FSDP
-        model = TorchFSDPModel(model, **self.fsdp_kwargs)
+        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
+        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = OptimizerWrapper(optimizer)
+        if not isinstance(optimizer, FSDPOptimizerWrapper):
+            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
-        return model, optimizer, criterion, dataloader, lr_scheduler
+        return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
 
     def control_checkpoint_io(self) -> bool:
         return True
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 96a883fdb..2cc9c3faa 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -1,26 +1,26 @@
-from pathlib import Path
+import gc
+import logging
+import os
 from functools import reduce
+from pathlib import Path
+from typing import Iterator, Optional, OrderedDict, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
 
-import logging
-import os
-import gc
-from typing import Optional, Iterator, OrderedDict, Tuple
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    has_index_file,
-    load_state_dict,
-    save_state_dict,
-    is_safetensors_available,
-    shard_checkpoint,
-    load_shard_state_dict,
-    load_state_dict_into_model,
+    get_base_filenames,
     get_shard_filename,
-    get_base_filenames
-    )
+    has_index_file,
+    is_safetensors_available,
+    load_shard_state_dict,
+    load_state_dict,
+    load_state_dict_into_model,
+    save_state_dict,
+    shard_checkpoint,
+)
 
 __all__ = ['GeneralCheckpointIO']
 
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
     """
     Checkpoint IO
    """
+
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)
@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
 
-
-    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
-                            variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
-        """ 
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        """
         implement this method as it can be supported by Huggingface model,
         save shard model, save model to multiple files
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
-        
+
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
-        
+
         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
-            
+
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
-        
+
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
-            f"The model is going to be split to checkpoint shards. "
-            f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
-        )
+        logging.info(f"The model is going to be split to checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
 
-
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
-                            use_safetensors: bool = False, load_sub_module: bool = True):
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         if use_safetensors and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
-        
+
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
             if len(remain_keys) > 0:
-                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
+                    '"{}"'.format(k) for k in missing_keys))
                 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                    self.__class__.__name__, "\n\t".join(error_msgs)))
-
-
-
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
index 12562095c..44767f051 100644
--- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
@@ -1,10 +1,6 @@
-from contextlib import nullcontext
-
 import pytest
 import torch
-import torch.distributed as dist
 from packaging import version
-from torch import nn
 from torch.optim import SGD
 
 import colossalai
@@ -19,6 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
 
+# test basic fsdp function
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
new file mode 100644
index 000000000..2b6090bb1
--- /dev/null
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -0,0 +1,113 @@
+import pytest
+import torch
+from packaging import version
+from torch import nn
+from torch.optim import SGD
+from torchvision.models import resnet18
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from colossalai.booster.plugin import TorchFSDPPlugin
+
+from colossalai.testing import rerun_if_address_is_in_use, spawn, check_state_dict_equal
+
+
+def compare_nested_dict(dict1, dict2):
+    for key in dict1:
+        if key in dict2:
+            if type(dict1[key]) is dict:
+                assert type(dict2[key]) is dict
+                diff = compare_nested_dict(dict1[key], dict2[key])
+                if not diff:
+                    return diff
+            elif type(dict1[key]) is list:
+                assert type(dict2[key]) is list
+                for i, val in enumerate(dict1[key]):
+                    if isinstance(val, torch.Tensor):
+                        if not torch.equal(dict1[key][i], dict2[key][i]):
+                            return False
+                    elif val != dict2[key][i]:
+                        return False
+            elif type(dict1[key]) is torch.Tensor:
+                assert type(dict2[key]) is torch.Tensor
+                if not torch.equal(dict1[key], dict2[key]):
+                    return False
+            else:
+                if dict1[key] != dict2[key]:
+                    return False
+        else:
+            return False
+    return True
+
+
+def check_torch_fsdp_ckpt():
+    model = resnet18()
+    plugin = TorchFSDPPlugin()
+    booster = Booster(plugin=plugin)
+    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
+    criterion = lambda x: x.mean()
+    fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    inputs = torch.randn(4, 3, 224, 224)
+    outputs = None
+
+    def run_model():
+        nonlocal outputs
+        outputs = fsdp_model(inputs)
+        optimizer.zero_grad()
+        criterion(outputs).backward()
+        optimizer.step()
+
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optim_ckpt_path = f"{tempdir}/optimizer"
+
+        run_model()
+
+        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
+
+        full_msd = fsdp_model.state_dict()
+        #full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
+        sharded_osd = optimizer.state_dict()
+        import copy
+        sharded_osd = copy.deepcopy(sharded_osd)
+
+        run_model()
+
+        full_msd_updated = fsdp_model.state_dict()
+        #full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
+        sharded_osd_updated = optimizer.state_dict()
+
+        assert not compare_nested_dict(sharded_osd, sharded_osd_updated)
+        assert not compare_nested_dict(full_msd_updated, full_msd)
+        outputs_first = fsdp_model(inputs)
+        assert criterion(outputs_first) != criterion(outputs)
+
+        booster.load_model(fsdp_model, model_ckpt_path)
+        booster.load_optimizer(optimizer, optim_ckpt_path)
+
+        full_msd_restore = fsdp_model.state_dict()
+        #full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
+        sharded_osd_restore = optimizer.state_dict()
+
+        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
+        assert compare_nested_dict(full_msd_restore, full_msd)
+        outputs_sec = fsdp_model(inputs)
+        assert criterion(outputs_sec) == criterion(outputs)
+
+
+def run_dist(rank, world_size, port):
+    # init dist env
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    check_torch_fsdp_ckpt()
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
+@rerun_if_address_is_in_use()
+def test_torch_fsdp_ckpt():
+    spawn(run_dist, 2)
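
For reference, a minimal usage sketch of the checkpoint path this patch fixes, mirroring the new test above. It is a sketch under stated assumptions, not part of the patch: it assumes torch >= 1.12 with CUDA, a distributed environment already initialized via colossalai.launch (e.g. through spawn or torchrun), and placeholder checkpoint file names; only unsharded (shard=False) checkpoints are supported by TorchFSDPCheckpointIO in this change.

# Minimal sketch (assumes the process group is already set up via colossalai.launch
# and a CUDA device is available; the checkpoint file names below are placeholders).
import torch
from torch.optim import SGD
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = lambda x: x.mean()

# boost() wraps the model in FSDP and rebinds the optimizer to the flattened FSDP parameters
fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# run one step so the optimizer has state worth checkpointing
out = fsdp_model(torch.randn(4, 3, 224, 224, device="cuda"))
optimizer.zero_grad()
criterion(out).backward()
optimizer.step()

# save and restore full (unsharded) model and optimizer state dicts
booster.save_model(fsdp_model, "model_ckpt.pt", shard=False)
booster.save_optimizer(optimizer, "optim_ckpt.pt", shard=False)
booster.load_model(fsdp_model, "model_ckpt.pt")
booster.load_optimizer(optimizer, "optim_ckpt.pt")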