mirror of https://github.com/hpcaitech/ColossalAI
[booster] support torch fsdp plugin in booster (#3697)
Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>
parent ad6460cf2c
commit b37797ed3d
@@ -4,3 +4,10 @@ from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from .torch_fsdp_plugin import TorchFSDPPlugin
    __all__.append('TorchFSDPPlugin')
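
The guarded import above means TorchFSDPPlugin is only exported from colossalai.booster.plugin when torch >= 1.12.0 is installed. A minimal sketch of how downstream code might probe for it at runtime (illustrative only, not part of this commit):

import colossalai.booster.plugin as plugin_mod

# the name only exists in the package namespace on torch >= 1.12.0
if hasattr(plugin_mod, 'TorchFSDPPlugin'):
    plugin = plugin_mod.TorchFSDPPlugin()
else:
    raise RuntimeError("TorchFSDPPlugin requires torch >= 1.12.0")
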
@@ -0,0 +1,285 @@
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
        torch.__version__) < version.parse('2.0.0'):
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        MixedPrecision,
        ShardingStrategy,
    )
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp._init_utils import ProcessGroupType
    from torch.distributed.fsdp.api import (
        BackwardPrefetch,
        CPUOffload,
        FullOptimStateDictConfig,
        FullStateDictConfig,
        MixedPrecision,
        ShardingStrategy,
        StateDictType,
    )
    from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
    raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .dp_plugin_base import DPPluginBase

__all__ = ['TorchFSDPPlugin']


class TorchFSDPCheckpointIO(GeneralCheckpointIO):

    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()

    def __set_model_optim_state(
        self,
        model,
        state_dict_type,
        state_dict_config,
        optim_state_dict_config=None,
    ):
        # optim_state_dict_config defaults to None so callers that only configure
        # the model state dict can omit it; FSDP.set_state_dict_type accepts None
        return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)

    def load_sharded_model(self, model: nn.Module, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded optimizer checkpoint is not supported yet.")

    def save_sharded_model(self, model: nn.Module, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded optimizer checkpoint is not supported yet.")

    def load_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Load model from checkpoint with automatic unwrapping.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            full_state_dict = self.load_state_dict(checkpoint)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            full_state_dict = self.load_state_dict(checkpoint)
            # configure FULL_STATE_DICT so the unsharded dict can be loaded;
            # do not overwrite the loaded dict with the model's current weights
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

        model.load_state_dict(full_state_dict)

    def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
        """
        Load optimizer from checkpoint with automatic unwrapping.
        """
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            optim_full_state_dict = self.load_state_dict(checkpoint)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            optim_full_state_dict = self.load_state_dict(checkpoint)
            # keep the return value: the call produces the state dict that this
            # rank's optimizer can actually load
            optim_full_state_dict = FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

        optim.load_state_dict(optim_full_state_dict)

    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Save model to checkpoint but only on master process.
        """
        # the model should be unwrapped in self.save_model via ModelWrapper.unwrap
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
                model_state_dict = model.state_dict()
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
            model_state_dict = model.state_dict()
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")
        self.save_checkpoint(model_state_dict, checkpoint)

    def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        """
        Save optimizer to checkpoint but only on master process.
        """
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            # pass the optimizer config in its own slot; the third argument is
            # the model state dict config
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
                                         FullStateDictConfig(rank0_only=True),
                                         FullOptimStateDictConfig(rank0_only=True))
            optim_state_dict = FSDP.optim_state_dict(model, optimizer)
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")
        self.save_checkpoint(optim_state_dict, checkpoint)
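
# Note on the unsharded save/load paths above: StateDictType.FULL_STATE_DICT
# makes FSDP all-gather the flattened parameter shards into an ordinary,
# rank-complete state dict, and rank0_only=True materializes that dict on
# rank 0 only, which is why these methods save "only on master process".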


class TorchFSDPModel(ModelWrapper):

    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        self.module = FSDP(module, *args, **kwargs)

    def unwrap(self):
        return self.module.module


class TorchFSDPPlugin(DPPluginBase):
    """
    Plugin for PyTorch FSDP.

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import TorchFSDPPlugin
        >>>
        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = TorchFSDPPlugin()
        >>>
        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

    Args:
        See https://pytorch.org/docs/stable/fsdp.html for details.
    """

    if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
            torch.__version__) < version.parse('2.0.0'):

        def __init__(
            self,
            process_group: Optional[ProcessGroup] = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Callable] = None,
            backward_prefetch: Optional[BackwardPrefetch] = None,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            device_id: Optional[Union[int, torch.device]] = None,
            sync_module_states: bool = False,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(process_group=process_group,
                                    sharding_strategy=sharding_strategy,
                                    cpu_offload=cpu_offload,
                                    auto_wrap_policy=auto_wrap_policy,
                                    backward_prefetch=backward_prefetch,
                                    mixed_precision=mixed_precision,
                                    ignored_modules=ignored_modules,
                                    param_init_fn=param_init_fn,
                                    device_id=device_id,
                                    sync_module_states=sync_module_states)
    elif version.parse(torch.__version__) >= version.parse('2.0.0'):

        def __init__(
            self,
            process_group: ProcessGroupType = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
            backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            device_id: Optional[Union[int, torch.device]] = None,
            sync_module_states: bool = False,
            forward_prefetch: bool = False,
            limit_all_gathers: bool = False,
            use_orig_params: bool = False,
            ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(process_group=process_group,
                                    sharding_strategy=sharding_strategy,
                                    cpu_offload=cpu_offload,
                                    auto_wrap_policy=auto_wrap_policy,
                                    backward_prefetch=backward_prefetch,
                                    mixed_precision=mixed_precision,
                                    ignored_modules=ignored_modules,
                                    param_init_fn=param_init_fn,
                                    device_id=device_id,
                                    sync_module_states=sync_module_states,
                                    forward_prefetch=forward_prefetch,
                                    limit_all_gathers=limit_all_gathers,
                                    use_orig_params=use_orig_params,
                                    ignored_parameters=ignored_parameters)
    else:
        raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

    def support_no_sync(self) -> bool:
        return False

    def no_sync(self, model: nn.Module) -> Iterator[None]:
        raise NotImplementedError("Torch FSDP no_sync function is not supported yet.")

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return ['fp16', 'bf16']

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def configure(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable = None,
        dataloader: DataLoader = None,
        lr_scheduler: LRScheduler = None,
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

        # move the model to GPU before wrapping, as FSDP expects CUDA parameters
        model = model.cuda()
        # wrap the model with PyTorch FSDP
        model = TorchFSDPModel(model, **self.fsdp_kwargs)

        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return TorchFSDPCheckpointIO()
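
Because TorchFSDPPlugin forwards its keyword arguments verbatim to FullyShardedDataParallel via fsdp_kwargs, the standard torch FSDP configuration objects pass straight through. A minimal sketch (the fp16 dtype choices and CPU offload here are illustrative assumptions, not part of the commit):

import torch
from torch.distributed.fsdp import CPUOffload, MixedPrecision

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

# fp16 compute with parameters offloaded to CPU; both objects reach
# FullyShardedDataParallel unchanged through the plugin's configure() step
plugin = TorchFSDPPlugin(
    mixed_precision=MixedPrecision(param_dtype=torch.float16,
                                   reduce_dtype=torch.float16,
                                   buffer_dtype=torch.float16),
    cpu_offload=CPUOffload(offload_params=True),
)
booster = Booster(plugin=plugin)
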
@@ -167,10 +167,10 @@ def rerun_if_address_is_in_use():
     """
     # check version
     torch_version = version.parse(torch.__version__)
-    assert torch_version.major == 1
+    assert torch_version.major >= 1

     # only torch >= 1.8 has ProcessRaisedException
-    if torch_version.minor >= 8:
+    if torch_version >= version.parse("1.8.0"):
         exception = torch.multiprocessing.ProcessRaisedException
     else:
         exception = Exception
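
The two replaced checks broke on torch 2.x: packaging's version.parse gives major == 2 and minor == 0 for "2.0.0", so the old assertion failed outright and the minor-only test wrongly fell back to the generic Exception. A quick illustration of the difference:

from packaging import version

v = version.parse("2.0.0")
assert v.major == 2 and v.minor == 0    # the old major == 1 assert fails here
assert not (v.minor >= 8)               # minor-only check misclassifies torch 2.x
assert v >= version.parse("1.8.0")      # the replacement comparison handles 2.x
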
@@ -0,0 +1,64 @@
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    from colossalai.booster.plugin import TorchFSDPPlugin

from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)
    model = model_fn()
    optimizer = SGD(model.parameters(), lr=1e-3)
    criterion = lambda x: x.mean()
    data = data_gen_fn()

    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    assert isinstance(model.module, FSDP)
    assert isinstance(optimizer, OptimizerWrapper)

    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()


def check_torch_fsdp_plugin():
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        if 'diffusers' in name:
            continue
        run_fn(model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_fsdp_plugin()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch 1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
    spawn(run_dist, 2)
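
ColossalAI test modules conventionally end with a direct entry point so a test can be run standalone as well as under pytest; a sketch of that convention (assumed here, not shown in this hunk):

if __name__ == '__main__':
    test_torch_fsdp_plugin()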