Browse Source

[booster] support torch fsdp plugin in booster (#3697)

Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>
pull/3725/head
wukong1992 2 years ago committed by GitHub
parent
commit
b37797ed3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      colossalai/booster/plugin/__init__.py
  2. 285
      colossalai/booster/plugin/torch_fsdp_plugin.py
  3. 4
      colossalai/testing/utils.py
  4. 64
      tests/test_booster/test_plugin/test_torch_fsdp_plugin.py

7
colossalai/booster/plugin/__init__.py

@ -4,3 +4,10 @@ from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse('1.12.0'):
from .torch_fsdp_plugin import TorchFSDPPlugin
__all__.append('TorchFSDPPlugin')

285
colossalai/booster/plugin/torch_fsdp_plugin.py

@ -0,0 +1,285 @@
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._init_utils import ProcessGroupType
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase
__all__ = ['TorchFSDPPlugin']
class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def __set_model_optim_state(
self,
model,
state_dict_type,
state_dict_config,
optim_state_dict_config,
):
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
def load_sharded_model(self, model: nn.Module, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def save_sharded_model(self, model: nn.Module, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
full_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
model.load_state_dict(full_state_dict)
def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
"""
Load Optimizer from checkpoint with automatic unwrapping.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
optim.load_state_dict(optim_full_state_dict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
model_state_dict = model.state_dict()
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
model_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(model_state_dict, checkpoint)
def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
"""
Save optimizer to checkpoint but only on master process.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
FullOptimStateDictConfig(rank0_only=True))
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(optim_state_dict, checkpoint)
class TorchFSDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
def unwrap(self):
return self.module.module
class TorchFSDPPlugin(DPPluginBase):
"""
Plugin for PyTorch FSDP.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import TorchFSDPPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = TorchFSDPPlugin()
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
Args:
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
def __init__(
self,
process_group: Optional[ProcessGroup] = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Callable] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
def __init__(
self,
process_group: ProcessGroupType = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_parameters=ignored_parameters)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
def support_no_sync(self) -> bool:
False
def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16']
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
model = model.cuda()
# wrap the model with PyTorch FSDP
model = TorchFSDPModel(model, **self.fsdp_kwargs)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return TorchFSDPCheckpointIO()

4
colossalai/testing/utils.py

@ -167,10 +167,10 @@ def rerun_if_address_is_in_use():
"""
# check version
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
assert torch_version.major >= 1
# only torch >= 1.8 has ProcessRaisedException
if torch_version.minor >= 8:
if torch_version >= version.parse("1.8.0"):
exception = torch.multiprocessing.ProcessRaisedException
else:
exception = Exception

64
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py

@ -0,0 +1,64 @@
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.optim import SGD
import colossalai
from colossalai.booster import Booster
if version.parse(torch.__version__) >= version.parse('1.12.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.booster.plugin import TorchFSDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
assert isinstance(model.module, FSDP)
assert isinstance(optimizer, OptimizerWrapper)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
if 'diffusers' in name:
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
def run_dist(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_torch_fsdp_plugin()
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
spawn(run_dist, 2)
Loading…
Cancel
Save