mirror of https://github.com/hpcaitech/ColossalAI
[booster] torch fsdp fix ckpt (#3788)
parent 9265f2d4d7
commit 6b305a99d6
@@ -196,7 +196,7 @@ class Booster:
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
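Note on the hunk above: the old call passed `prefix` positionally into `CheckpointIO.save_model`, so the later arguments (`shard`, `size_per_shard`) no longer lined up with the parameters they were meant for; keyword arguments remove the misalignment. A minimal round-trip sketch of the API this protects, distilled from the new test at the bottom of this commit (assumes the distributed environment is already launched via `colossalai.launch` and a CUDA device is available; the checkpoint paths are placeholders):

```python
import torch
from torch.optim import SGD
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

model = resnet18()
optimizer = SGD(model.parameters(), lr=1e-3)
booster = Booster(plugin=TorchFSDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# keyword arguments now line up with CheckpointIO.save_model's parameters
booster.save_model(model, "ckpt/model.pt", shard=False)
booster.save_optimizer(optimizer, "ckpt/optimizer.pt", shard=False)
booster.load_model(model, "ckpt/model.pt")
booster.load_optimizer(optimizer, "ckpt/optimizer.pt")
```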
@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -5,30 +6,18 @@ import torch.nn as nn
 from packaging import version
 from torch.distributed import ProcessGroup

-if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-        torch.__version__) < version.parse('2.0.0'):
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
     from torch.distributed.fsdp import FullStateDictConfig
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp import StateDictType
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         MixedPrecision,
         ShardingStrategy,
     )
-elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp._init_utils import ProcessGroupType
-    from torch.distributed.fsdp.api import (
-        BackwardPrefetch,
-        CPUOffload,
-        FullOptimStateDictConfig,
-        FullStateDictConfig,
-        MixedPrecision,
-        ShardingStrategy,
-        StateDictType,
-    )
-    from torch.distributed.fsdp.wrap import _FSDPPolicy
 else:
     raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -36,7 +25,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()

-    def __set_model_optim_state(
-        self,
-        model,
-        state_dict_type,
-        state_dict_config,
-        optim_state_dict_config,
-    ):
-        return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
-
-    def load_sharded_model(self, model: nn.Module, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def save_sharded_model(self, model: nn.Module, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str):
-        """
-        Load model from checkpoint with automatic unwrapping.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            full_state_dict = self.load_state_dict(checkpoint)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            full_state_dict = self.load_state_dict(checkpoint)
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
-            full_state_dict = model.state_dict()
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-
-        model.load_state_dict(full_state_dict)
-
-    def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
-        """
-        Load Optimizer from checkpoint with automatic unwrapping.
-        """
-
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            optim_full_state_dict = self.load_state_dict(checkpoint)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            optim_full_state_dict = self.load_state_dict(checkpoint)
-            FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-
-        optim.load_state_dict(optim_full_state_dict)
-
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
-        """
-        Save model to checkpoint but only on master process.
-        """
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
-                model_state_dict = model.state_dict()
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
-            model_state_dict = model.state_dict()
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        self.save_checkpoint(model_state_dict, checkpoint)
-
-    def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-        """
-        Save optimizer to checkpoint but only on master process.
-        """
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
-                                         FullOptimStateDictConfig(rank0_only=True))
-            optim_state_dict = FSDP.optim_state_dict(model, optimizer)
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        self.save_checkpoint(optim_state_dict, checkpoint)
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        checkpoint = utils.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint)
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = utils.load_state_dict(checkpoint)
+        fsdp_model = optimizer.unwrap_model()
+        sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
+        optimizer.load_state_dict(sharded_osd)
+
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
+            full_model_state = model.state_dict()
+        utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        assert isinstance(optimizer, FSDPOptimizerWrapper)
+        fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
+        utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
+
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+                           size_per_shard: int, use_safetensors: bool):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        """
+        Load model to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+        """
+        Load optimizer to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         Save model to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)


 class TorchFSDPModel(ModelWrapper):
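This rewrite of `TorchFSDPCheckpointIO` is the core of the fix: instead of version-gated branches that did not round-trip correctly, saving now gathers a full, CPU-offloaded, rank0-only state dict under `FSDP.state_dict_type`, and loading scatters the consolidated optimizer state back to each rank. A condensed sketch of that round trip with the wrapper plumbing stripped away (uses only documented `torch.distributed.fsdp` APIs; `torch.save`/`torch.load` stand in for the `utils` helpers, and an initialised process group is assumed):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_full(fsdp_model: FSDP, optimizer: torch.optim.Optimizer, model_path: str, optim_path: str):
    # gather the complete model state dict on rank 0, offloaded to CPU
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        full_model_state = fsdp_model.state_dict()
    # consolidate the optimizer state across ranks; only rank 0 receives it
    full_optim_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
    if dist.get_rank() == 0:
        torch.save(full_model_state, model_path)
        torch.save(full_optim_state, optim_path)


def load_full(fsdp_model: FSDP, optimizer: torch.optim.Optimizer, model_path: str, optim_path: str):
    fsdp_model.load_state_dict(torch.load(model_path))
    # re-shard the full optimizer state for this rank before loading it
    sharded_osd = FSDP.scatter_full_optim_state_dict(torch.load(optim_path), fsdp_model)
    optimizer.load_state_dict(sharded_osd)
```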
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
         self.module = FSDP(module, *args, **kwargs)

     def unwrap(self):
-        return self.module.module
+        return self.module
+
+
+class FSDPOptimizerWrapper(OptimizerWrapper):
+
+    def __init__(self, optimizer: Optimizer, model: nn.Module):
+        self.model = model
+        super().__init__(optimizer)
+
+    def unwrap_model(self) -> nn.Module:
+        return self.model


 class TorchFSDPPlugin(DPPluginBase):
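`FSDPOptimizerWrapper` exists because FSDP's optimizer-state APIs need the wrapped module as well as the optimizer: `FSDP.full_optim_state_dict` and `FSDP.scatter_full_optim_state_dict` both take the model as an argument, while the checkpoint IO methods above only receive the optimizer. The wrapper carries the module handle across that boundary (names as defined in this commit):

```python
# sketch: how the checkpoint IO gets from an optimizer back to its FSDP module
optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)

# later, inside save_unsharded_optimizer:
fsdp_model = optimizer.unwrap_model()
full_osd = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
```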
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
     See https://pytorch.org/docs/stable/fsdp.html for details.
     """

-    if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-            torch.__version__) < version.parse('2.0.0'):
+    if version.parse(torch.__version__) >= version.parse('1.12.0'):

         def __init__(
             self,
@@ -191,7 +158,6 @@ class TorchFSDPPlugin(DPPluginBase):
             mixed_precision: Optional[MixedPrecision] = None,
             ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
             param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-            device_id: Optional[Union[int, torch.device]] = None,
             sync_module_states: bool = False,
         ):
             super().__init__()
@@ -203,42 +169,7 @@ class TorchFSDPPlugin(DPPluginBase):
                                     mixed_precision=mixed_precision,
                                     ignored_modules=ignored_modules,
                                     param_init_fn=param_init_fn,
-                                    device_id=device_id,
                                     sync_module_states=sync_module_states)
-    elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-
-        def __init__(
-            self,
-            process_group: ProcessGroupType = None,
-            sharding_strategy: Optional[ShardingStrategy] = None,
-            cpu_offload: Optional[CPUOffload] = None,
-            auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
-            backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
-            mixed_precision: Optional[MixedPrecision] = None,
-            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
-            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-            device_id: Optional[Union[int, torch.device]] = None,
-            sync_module_states: bool = False,
-            forward_prefetch: bool = False,
-            limit_all_gathers: bool = False,
-            use_orig_params: bool = False,
-            ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
-        ):
-            super().__init__()
-            self.fsdp_kwargs = dict(process_group=process_group,
-                                    sharding_strategy=sharding_strategy,
-                                    cpu_offload=cpu_offload,
-                                    auto_wrap_policy=auto_wrap_policy,
-                                    backward_prefetch=backward_prefetch,
-                                    mixed_precision=mixed_precision,
-                                    ignored_modules=ignored_modules,
-                                    param_init_fn=param_init_fn,
-                                    device_id=device_id,
-                                    sync_module_states=sync_module_states,
-                                    forward_prefetch=forward_prefetch,
-                                    limit_all_gathers=limit_all_gathers,
-                                    use_orig_params=use_orig_params,
-                                    ignored_parameters=ignored_parameters)
     else:
         raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
         lr_scheduler: LRScheduler = None,
     ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

-        model = model.cuda()
         # wrap the model with PyTorch FSDP
-        model = TorchFSDPModel(model, **self.fsdp_kwargs)
+        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
+        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = OptimizerWrapper(optimizer)
+        if not isinstance(optimizer, FSDPOptimizerWrapper):
+            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)

-        return model, optimizer, criterion, dataloader, lr_scheduler
+        return fsdp_model, optimizer, criterion, dataloader, lr_scheduler

     def control_checkpoint_io(self) -> bool:
         return True
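Two details in `configure` are worth calling out. First, passing `device_id=torch.cuda.current_device()` lets FSDP place the module itself, replacing the old eager `model.cuda()`. Second, FSDP flattens and replaces the module's parameters when it wraps the model, so an optimizer built over the original `model.parameters()` would be updating tensors that are no longer used; re-running `optimizer.__init__` over `fsdp_model.parameters()` rebinds it. The conventional ordering, for comparison, is to construct the optimizer after wrapping (a sketch, assuming an initialised process group and a CUDA device):

```python
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import SGD

model = nn.Linear(1024, 1024)
fsdp_model = FSDP(model, device_id=torch.cuda.current_device())
# created after wrapping, so it sees the flattened FSDP parameters
optimizer = SGD(fsdp_model.parameters(), lr=1e-3)
```

One caveat of the in-place `optimizer.__init__(...)` rebind: it re-applies only `optimizer.defaults`, so any per-parameter-group options the caller had configured are dropped.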
@@ -1,26 +1,26 @@
-from pathlib import Path
+import gc
+import logging
+import os
+from functools import reduce
+from pathlib import Path
+from typing import Iterator, Optional, OrderedDict, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer
-import logging
-import os
-import gc
-from typing import Optional, Iterator, OrderedDict, Tuple

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    get_base_filenames,
+    get_shard_filename,
     has_index_file,
-    load_state_dict,
-    save_state_dict,
     is_safetensors_available,
-    shard_checkpoint,
     load_shard_state_dict,
+    load_state_dict,
     load_state_dict_into_model,
-    get_shard_filename,
-    get_base_filenames
+    save_state_dict,
+    shard_checkpoint,
 )

 __all__ = ['GeneralCheckpointIO']
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
     """
     Checkpoint IO
     """

     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)
@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

-    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
-                           variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
         """
         implement this method as it can be supported by Huggingface model,
         save shard model, save model to multiple files
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return

         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)

             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
-            f"The model is going to be split to checkpoint shards. "
-            f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
-        )
+        logging.info(f"The model is going to be split to checkpoint shards. "
+                     f"You can find where each parameters has been saved in the "
+                     f"index located at {save_index_file}.")

-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
-                           use_safetensors: bool = False, load_sub_module: bool = True):
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
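For context, the sharded-save path above is bookkeeping around `shard_checkpoint`: each tensor lands in a size-bounded shard file, and the index records a weight map from parameter name to shard file plus the total size, so a loader can locate any tensor without reading every shard. A stripped-down sketch of that layout (hypothetical `save_sharded` helper and plain JSON index; the real code goes through `CheckpointIndexFile`):

```python
import json
import os
from typing import Dict, Iterable, Tuple

import torch


def save_sharded(shards: Iterable[Tuple[Dict[str, torch.Tensor], int]], checkpoint_dir: str) -> None:
    """Write (shard, size) pairs and a JSON index mapping each weight to its shard file."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    weight_map, total_size = {}, 0
    for i, (shard, size) in enumerate(shards):
        shard_file = f"model-{i:05d}.bin"
        torch.save(shard, os.path.join(checkpoint_dir, shard_file))
        total_size += size
        for key in shard:
            weight_map[key] = shard_file
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, "model.index.json"), "w") as f:
        json.dump(index, f, indent=2)
```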
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if use_safetensors and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
             if len(remain_keys) > 0:
-                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
+                    '"{}"'.format(k) for k in missing_keys))
                 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                     self.__class__.__name__, "\n\t".join(error_msgs)))
-
-
-
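A note on the strict check above: `missing_keys` collects one list per shard file, and a key absent from shard A may be supplied by shard B. The set intersection therefore keeps only keys that no shard provided, which are the genuinely missing ones:

```python
from functools import reduce

# one missing-key list per shard; only "fc.bias" is missing from every shard
missing_keys = [["conv1.weight", "fc.bias"], ["fc.bias", "fc.weight"]]
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
assert remain_keys == {"fc.bias"}
```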
@@ -1,10 +1,6 @@
-from contextlib import nullcontext
-
 import pytest
 import torch
-import torch.distributed as dist
 from packaging import version
-from torch import nn
 from torch.optim import SGD

 import colossalai
@@ -19,6 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


+# test baisc fsdp function
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
@@ -0,0 +1,113 @@
+import pytest
+import torch
+from packaging import version
+from torch import nn
+from torch.optim import SGD
+from torchvision.models import resnet18
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from colossalai.booster.plugin import TorchFSDPPlugin
+
+from colossalai.testing import rerun_if_address_is_in_use, spawn, check_state_dict_equal
+
+
+def compare_nested_dict(dict1, dict2):
+    for key in dict1:
+        if key in dict2:
+            if type(dict1[key]) is dict:
+                assert type(dict2[key]) is dict
+                diff = compare_nested_dict(dict1[key], dict2[key])
+                if not diff:
+                    return diff
+            elif type(dict1[key]) is list:
+                assert type(dict2[key]) is list
+                for i, val in enumerate(dict1[key]):
+                    if isinstance(val, torch.Tensor):
+                        if not torch.equal(dict1[key][i], dict2[key][i]):
+                            return False
+                    elif val != dict2[key][i]:
+                        return False
+            elif type(dict1[key]) is torch.Tensor:
+                assert type(dict2[key]) is torch.Tensor
+                if not torch.equal(dict1[key], dict2[key]):
+                    return False
+            else:
+                if dict1[key] != dict2[key]:
+                    return False
+        else:
+            return False
+    return True
+
+
+def check_torch_fsdp_ckpt():
+    model = resnet18()
+    plugin = TorchFSDPPlugin()
+    booster = Booster(plugin=plugin)
+    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
+    criterion = lambda x: x.mean()
+    fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    inputs = torch.randn(4, 3, 224, 224)
+    outputs = None
+
+    def run_model():
+        nonlocal outputs
+        outputs = fsdp_model(inputs)
+        optimizer.zero_grad()
+        criterion(outputs).backward()
+        optimizer.step()
+
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optim_ckpt_path = f"{tempdir}/optimizer"
+
+        run_model()
+
+        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
+
+        full_msd = fsdp_model.state_dict()
+        #full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
+        sharded_osd = optimizer.state_dict()
+        import copy
+        sharded_osd = copy.deepcopy(sharded_osd)
+
+        run_model()
+
+        full_msd_updated = fsdp_model.state_dict()
+        #full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
+        sharded_osd_updated = optimizer.state_dict()
+
+        assert not compare_nested_dict(sharded_osd, sharded_osd_updated)
+        assert not compare_nested_dict(full_msd_updated, full_msd)
+        outputs_first = fsdp_model(inputs)
+        assert criterion(outputs_first) != criterion(outputs)
+
+        booster.load_model(fsdp_model, model_ckpt_path)
+        booster.load_optimizer(optimizer, optim_ckpt_path)
+
+        full_msd_restore = fsdp_model.state_dict()
+        #full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
+        sharded_osd_restore = optimizer.state_dict()
+
+        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
+        assert compare_nested_dict(full_msd_restore, full_msd)
+        outputs_sec = fsdp_model(inputs)
+        assert criterion(outputs_sec) == criterion(outputs)
+
+
+def run_dist(rank, world_size, port):
+    # init dist env
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    check_torch_fsdp_ckpt()
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
+@rerun_if_address_is_in_use()
+def test_torch_fsdp_ckpt():
+    spawn(run_dist, 2)
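The new test exercises the whole fix on two ranks: run a training step, save unsharded model and optimizer checkpoints, take another step so the live state diverges, then load the checkpoints back and assert that both state dicts and the forward loss match the saved point. It can be selected by name in the usual way (exact invocation depends on the repo's test layout):

```
pytest -k test_torch_fsdp_ckpt
```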