[booster] torch fsdp fix ckpt (#3788)

pull/3813/head
wukong1992 2023-05-23 16:58:45 +08:00 committed by GitHub
parent 9265f2d4d7
commit 6b305a99d6
5 changed files with 230 additions and 186 deletions

View File

@@ -196,7 +196,7 @@ class Booster:
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.

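The fix above forwards checkpoint, shard and size_per_shard to the checkpoint IO by keyword and no longer forwards prefix, so the arguments cannot land in the wrong positional slots of CheckpointIO.save_model. A usage fragment matching the new FSDP checkpoint test at the bottom of this page (it reuses the booster, fsdp_model and optimizer objects constructed there; the paths are illustrative):

booster.save_model(fsdp_model, "/tmp/ckpt/model", shard=False)        # forwarded as checkpoint=..., shard=False, size_per_shard=1024
booster.save_optimizer(optimizer, "/tmp/ckpt/optimizer", shard=False)
booster.load_model(fsdp_model, "/tmp/ckpt/model")
booster.load_optimizer(optimizer, "/tmp/ckpt/optimizer")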
View File

@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
@@ -5,30 +6,18 @@ import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) >= version.parse('1.12.0'):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._init_utils import ProcessGroupType
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -36,7 +25,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
super().__init__()
self.coordinator = DistCoordinator()
def __set_model_optim_state(
self,
model,
state_dict_type,
state_dict_config,
optim_state_dict_config,
):
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
def load_sharded_model(self, model: nn.Module, checkpoint: str):
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def save_sharded_model(self, model: nn.Module, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
full_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
model.load_state_dict(full_state_dict)
def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
"""
Load Optimizer from checkpoint with automatic unwrapping.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
optim.load_state_dict(optim_full_state_dict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
model_state_dict = model.state_dict()
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
model_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(model_state_dict, checkpoint)
def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, FSDPOptimizerWrapper)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
FullOptimStateDictConfig(rank0_only=True))
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(optim_state_dict, checkpoint)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
"""
Load model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
"""
Load optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class TorchFSDPModel(ModelWrapper):
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
self.module = FSDP(module, *args, **kwargs)
def unwrap(self):
return self.module.module
return self.module
class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
def unwrap_model(self) -> nn.Module:
return self.model
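FSDPOptimizerWrapper keeps a handle on the wrapped model because the FSDP optimizer-state helpers used by TorchFSDPCheckpointIO (FSDP.full_optim_state_dict and FSDP.scatter_full_optim_state_dict) need the model, not just the optimizer. A tiny sketch of the contract the checkpoint IO relies on, with a plain nn.Linear standing in for the real FSDP-wrapped model and assuming OptimizerWrapper proxies the usual optimizer methods:

import torch.nn as nn
from torch.optim import SGD

module = nn.Linear(8, 8)                                 # stand-in; in practice this is the FSDP-wrapped model
wrapper = FSDPOptimizerWrapper(SGD(module.parameters(), lr=0.1), module)
assert wrapper.unwrap_model() is module                  # save/load_unsharded_optimizer depend on this handle
wrapper.zero_grad()                                      # the base optimizer API is still reachable through the wrapper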
class TorchFSDPPlugin(DPPluginBase):
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) >= version.parse('1.12.0'):
def __init__(
self,
@@ -191,7 +158,6 @@ class TorchFSDPPlugin(DPPluginBase):
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
):
super().__init__()
@@ -203,42 +169,7 @@ class TorchFSDPPlugin(DPPluginBase):
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
def __init__(
self,
process_group: ProcessGroupType = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_parameters=ignored_parameters)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
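TorchFSDPPlugin stores these constructor arguments in self.fsdp_kwargs and hands them to FSDP when the model is wrapped in configure(). A minimal usage sketch; the option values are illustrative, and the config dataclasses are the same torch FSDP ones imported at the top of this file:

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision

from colossalai.booster.plugin import TorchFSDPPlugin

# Keep parameters in GPU memory, run parameters and gradient reduction in fp16.
plugin = TorchFSDPPlugin(
    cpu_offload=CPUOffload(offload_params=False),
    mixed_precision=MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16),
)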
@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
model = model.cuda()
# wrap the model with PyTorch FSDP
model = TorchFSDPModel(model, **self.fsdp_kwargs)
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
if not isinstance(optimizer, FSDPOptimizerWrapper):
optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
return model, optimizer, criterion, dataloader, lr_scheduler
return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
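configure() now passes device_id instead of calling model.cuda() and, crucially, re-runs optimizer.__init__ against the wrapped model's parameters: FSDP replaces the original parameters with flattened ones, so an optimizer built before wrapping would still point at the old tensors. A standalone sketch of the same re-binding step, with an illustrative helper name:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import SGD


def fsdp_wrap_and_rebind(module: nn.Module, optimizer: SGD):
    # FSDP flattens and re-registers the parameters, so the optimizer's param groups
    # must be rebuilt from the wrapped module (this mirrors optimizer.__init__(...) above).
    fsdp_module = FSDP(module, device_id=torch.cuda.current_device())
    optimizer.__init__(fsdp_module.parameters(), **optimizer.defaults)
    return fsdp_module, optimizer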

View File

@@ -1,26 +1,26 @@
from pathlib import Path
import gc
import logging
import os
from functools import reduce
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
import torch.nn as nn
from torch.optim import Optimizer
import logging
import os
import gc
from typing import Optional, Iterator, OrderedDict, Tuple
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
has_index_file,
load_state_dict,
save_state_dict,
is_safetensors_available,
shard_checkpoint,
load_shard_state_dict,
load_state_dict_into_model,
get_base_filenames,
get_shard_filename,
get_base_filenames
)
has_index_file,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
save_state_dict,
shard_checkpoint,
)
__all__ = ['GeneralCheckpointIO']
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
"""
Checkpoint IO
"""
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)
@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
"""
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
implement this method as it can be supported by Huggingface model,
save shard model, save model to multiple files
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
use_safetensors: bool = False, load_sub_module: bool = True):
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
"""
load shard model, load model from multiple files
"""
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))
error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
self.__class__.__name__, "\n\t".join(error_msgs)))
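One detail worth noting in the strict-mode check above: missing_keys is accumulated per shard file, so a key only counts as truly missing if it was absent from every shard, which is exactly what the reduce over set intersections computes. A small runnable illustration (the key names are made up for the example):

from functools import reduce

missing_keys = [{"fc.bias", "layer4.0.conv1.weight"},    # keys not found while loading shard 1
                {"fc.bias"}]                             # keys not found while loading shard 2
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
assert remain_keys == {"fc.bias"}                        # only keys missing from all shards trigger the error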

View File

@@ -1,10 +1,6 @@
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.optim import SGD
import colossalai
@@ -19,6 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
# test baisc fsdp function
def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)

View File

@@ -0,0 +1,113 @@
import pytest
import torch
from packaging import version
from torch import nn
from torch.optim import SGD
from torchvision.models import resnet18
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
if version.parse(torch.__version__) >= version.parse('1.12.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.booster.plugin import TorchFSDPPlugin
from colossalai.testing import rerun_if_address_is_in_use, spawn, check_state_dict_equal
def compare_nested_dict(dict1, dict2):
for key in dict1:
if key in dict2:
if type(dict1[key]) is dict:
assert type(dict2[key]) is dict
diff = compare_nested_dict(dict1[key], dict2[key])
if not diff:
return diff
elif type(dict1[key]) is list:
assert type(dict2[key]) is list
for i, val in enumerate(dict1[key]):
if isinstance(val, torch.Tensor):
if not torch.equal(dict1[key][i], dict2[key][i]):
return False
elif val != dict2[key][i]:
return False
elif type(dict1[key]) is torch.Tensor:
assert type(dict2[key]) is torch.Tensor
if not torch.equal(dict1[key], dict2[key]):
return False
else:
if dict1[key] != dict2[key]:
return False
else:
return False
return True
def check_torch_fsdp_ckpt():
model = resnet18()
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = lambda x: x.mean()
fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
inputs = torch.randn(4, 3, 224, 224)
outputs = None
def run_model():
nonlocal outputs
outputs = fsdp_model(inputs)
optimizer.zero_grad()
criterion(outputs).backward()
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer"
run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=False)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
full_msd = fsdp_model.state_dict()
#full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
sharded_osd = optimizer.state_dict()
import copy
sharded_osd = copy.deepcopy(sharded_osd)
run_model()
full_msd_updated = fsdp_model.state_dict()
#full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
sharded_osd_updated = optimizer.state_dict()
assert not compare_nested_dict(sharded_osd, sharded_osd_updated)
assert not compare_nested_dict(full_msd_updated, full_msd)
outputs_first = fsdp_model(inputs)
assert criterion(outputs_first) != criterion(outputs)
booster.load_model(fsdp_model, model_ckpt_path)
booster.load_optimizer(optimizer, optim_ckpt_path)
full_msd_restore = fsdp_model.state_dict()
#full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
sharded_osd_restore = optimizer.state_dict()
assert compare_nested_dict(sharded_osd, sharded_osd_restore)
assert compare_nested_dict(full_msd_restore, full_msd)
outputs_sec = fsdp_model(inputs)
assert criterion(outputs_sec) == criterion(outputs)
def run_dist(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_torch_fsdp_ckpt()
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_ckpt():
spawn(run_dist, 2)
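The new test initializes the distributed environment through colossalai.launch, spawns two worker processes, and is skipped on torch < 1.12. Such ColossalAI test modules are conventionally also runnable directly through a main guard like the one below (an assumption here, the guard itself is not shown in this diff); under pytest each spawned rank runs check_torch_fsdp_ckpt.

if __name__ == '__main__':
    # assumed convention for ColossalAI test modules; not part of the diff above
    test_torch_fsdp_ckpt()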