mirror of https://github.com/hpcaitech/ColossalAI
wukong1992
2 years ago
committed by
GitHub
4 changed files with 358 additions and 2 deletions
@ -0,0 +1,285 @@ |
|||||||
|
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.nn as nn |
||||||
|
from packaging import version |
||||||
|
from torch.distributed import ProcessGroup |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( |
||||||
|
torch.__version__) < version.parse('2.0.0'): |
||||||
|
from torch.distributed.fsdp import FullStateDictConfig |
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
||||||
|
from torch.distributed.fsdp import StateDictType |
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import ( |
||||||
|
BackwardPrefetch, |
||||||
|
CPUOffload, |
||||||
|
MixedPrecision, |
||||||
|
ShardingStrategy, |
||||||
|
) |
||||||
|
elif version.parse(torch.__version__) >= version.parse('2.0.0'): |
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
||||||
|
from torch.distributed.fsdp._init_utils import ProcessGroupType |
||||||
|
from torch.distributed.fsdp.api import ( |
||||||
|
BackwardPrefetch, |
||||||
|
CPUOffload, |
||||||
|
FullOptimStateDictConfig, |
||||||
|
FullStateDictConfig, |
||||||
|
MixedPrecision, |
||||||
|
ShardingStrategy, |
||||||
|
StateDictType, |
||||||
|
) |
||||||
|
from torch.distributed.fsdp.wrap import _FSDPPolicy |
||||||
|
else: |
||||||
|
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") |
||||||
|
|
||||||
|
from torch.optim import Optimizer |
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler |
||||||
|
from torch.utils.data import DataLoader |
||||||
|
|
||||||
|
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO |
||||||
|
from colossalai.cluster import DistCoordinator |
||||||
|
from colossalai.interface import ModelWrapper, OptimizerWrapper |
||||||
|
|
||||||
|
from .dp_plugin_base import DPPluginBase |
||||||
|
|
||||||
|
__all__ = ['TorchFSDPPlugin'] |
||||||
|
|
||||||
|
|
||||||
|
class TorchFSDPCheckpointIO(GeneralCheckpointIO): |
||||||
|
|
||||||
|
def __init__(self) -> None: |
||||||
|
super().__init__() |
||||||
|
self.coordinator = DistCoordinator() |
||||||
|
|
||||||
|
def __set_model_optim_state( |
||||||
|
self, |
||||||
|
model, |
||||||
|
state_dict_type, |
||||||
|
state_dict_config, |
||||||
|
optim_state_dict_config, |
||||||
|
): |
||||||
|
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config) |
||||||
|
|
||||||
|
def load_sharded_model(self, model: nn.Module, checkpoint: str): |
||||||
|
|
||||||
|
# TODO(jishaomin): implement this method as it can be supported by Huggingface model |
||||||
|
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.") |
||||||
|
|
||||||
|
def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str): |
||||||
|
|
||||||
|
# TODO(jishaomin): implement this method as it can be supported by Huggingface model |
||||||
|
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.") |
||||||
|
|
||||||
|
def save_sharded_model(self, model: nn.Module, checkpoint: str): |
||||||
|
|
||||||
|
# TODO(jishaomin): implement this method as it can be supported by Huggingface model |
||||||
|
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.") |
||||||
|
|
||||||
|
def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str): |
||||||
|
|
||||||
|
# TODO(jishaomin): implement this method as it can be supported by Huggingface model |
||||||
|
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.") |
||||||
|
|
||||||
|
def load_unsharded_model(self, model: nn.Module, checkpoint: str): |
||||||
|
""" |
||||||
|
Load model from checkpoint with automatic unwrapping. |
||||||
|
""" |
||||||
|
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( |
||||||
|
torch.__version__) < version.parse('2.0.0'): |
||||||
|
full_state_dict = self.load_state_dict(checkpoint) |
||||||
|
elif version.parse(torch.__version__) >= version.parse('2.0.0'): |
||||||
|
full_state_dict = self.load_state_dict(checkpoint) |
||||||
|
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True)) |
||||||
|
full_state_dict = model.state_dict() |
||||||
|
else: |
||||||
|
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") |
||||||
|
|
||||||
|
model.load_state_dict(full_state_dict) |
||||||
|
|
||||||
|
def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str): |
||||||
|
""" |
||||||
|
Load Optimizer from checkpoint with automatic unwrapping. |
||||||
|
""" |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( |
||||||
|
torch.__version__) < version.parse('2.0.0'): |
||||||
|
optim_full_state_dict = self.load_state_dict(checkpoint) |
||||||
|
elif version.parse(torch.__version__) >= version.parse('2.0.0'): |
||||||
|
optim_full_state_dict = self.load_state_dict(checkpoint) |
||||||
|
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim) |
||||||
|
else: |
||||||
|
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") |
||||||
|
|
||||||
|
optim.load_state_dict(optim_full_state_dict) |
||||||
|
|
||||||
|
def save_unsharded_model(self, model: nn.Module, checkpoint: str): |
||||||
|
""" |
||||||
|
Save model to checkpoint but only on master process. |
||||||
|
""" |
||||||
|
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( |
||||||
|
torch.__version__) < version.parse('2.0.0'): |
||||||
|
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) |
||||||
|
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): |
||||||
|
model_state_dict = model.state_dict() |
||||||
|
elif version.parse(torch.__version__) >= version.parse('2.0.0'): |
||||||
|
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True)) |
||||||
|
model_state_dict = model.state_dict() |
||||||
|
else: |
||||||
|
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") |
||||||
|
self.save_checkpoint(model_state_dict, checkpoint) |
||||||
|
|
||||||
|
def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str): |
||||||
|
""" |
||||||
|
Save optimizer to checkpoint but only on master process. |
||||||
|
""" |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( |
||||||
|
torch.__version__) < version.parse('2.0.0'): |
||||||
|
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer) |
||||||
|
elif version.parse(torch.__version__) >= version.parse('2.0.0'): |
||||||
|
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, |
||||||
|
FullOptimStateDictConfig(rank0_only=True)) |
||||||
|
optim_state_dict = FSDP.optim_state_dict(model, optimizer) |
||||||
|
else: |
||||||
|
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") |
||||||
|
self.save_checkpoint(optim_state_dict, checkpoint) |
||||||
|
|
||||||
|
|
||||||
|
class TorchFSDPModel(ModelWrapper): |
||||||
|
|
||||||
|
def __init__(self, module: nn.Module, *args, **kwargs) -> None: |
||||||
|
super().__init__(module) |
||||||
|
self.module = FSDP(module, *args, **kwargs) |
||||||
|
|
||||||
|
def unwrap(self): |
||||||
|
return self.module.module |
||||||
|
|
||||||
|
|
||||||
|
class TorchFSDPPlugin(DPPluginBase): |
||||||
|
""" |
||||||
|
Plugin for PyTorch FSDP. |
||||||
|
|
||||||
|
Example: |
||||||
|
>>> from colossalai.booster import Booster |
||||||
|
>>> from colossalai.booster.plugin import TorchFSDPPlugin |
||||||
|
>>> |
||||||
|
>>> model, train_dataset, optimizer, criterion = ... |
||||||
|
>>> plugin = TorchFSDPPlugin() |
||||||
|
|
||||||
|
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) |
||||||
|
>>> booster = Booster(plugin=plugin) |
||||||
|
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) |
||||||
|
|
||||||
|
Args: |
||||||
|
See https://pytorch.org/docs/stable/fsdp.html for details. |
||||||
|
""" |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( |
||||||
|
torch.__version__) < version.parse('2.0.0'): |
||||||
|
|
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
process_group: Optional[ProcessGroup] = None, |
||||||
|
sharding_strategy: Optional[ShardingStrategy] = None, |
||||||
|
cpu_offload: Optional[CPUOffload] = None, |
||||||
|
auto_wrap_policy: Optional[Callable] = None, |
||||||
|
backward_prefetch: Optional[BackwardPrefetch] = None, |
||||||
|
mixed_precision: Optional[MixedPrecision] = None, |
||||||
|
ignored_modules: Optional[Iterable[torch.nn.Module]] = None, |
||||||
|
param_init_fn: Optional[Callable[[nn.Module], None]] = None, |
||||||
|
device_id: Optional[Union[int, torch.device]] = None, |
||||||
|
sync_module_states: bool = False, |
||||||
|
): |
||||||
|
super().__init__() |
||||||
|
self.fsdp_kwargs = dict(process_group=process_group, |
||||||
|
sharding_strategy=sharding_strategy, |
||||||
|
cpu_offload=cpu_offload, |
||||||
|
auto_wrap_policy=auto_wrap_policy, |
||||||
|
backward_prefetch=backward_prefetch, |
||||||
|
mixed_precision=mixed_precision, |
||||||
|
ignored_modules=ignored_modules, |
||||||
|
param_init_fn=param_init_fn, |
||||||
|
device_id=device_id, |
||||||
|
sync_module_states=sync_module_states) |
||||||
|
elif version.parse(torch.__version__) >= version.parse('2.0.0'): |
||||||
|
|
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
process_group: ProcessGroupType = None, |
||||||
|
sharding_strategy: Optional[ShardingStrategy] = None, |
||||||
|
cpu_offload: Optional[CPUOffload] = None, |
||||||
|
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None, |
||||||
|
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, |
||||||
|
mixed_precision: Optional[MixedPrecision] = None, |
||||||
|
ignored_modules: Optional[Iterable[torch.nn.Module]] = None, |
||||||
|
param_init_fn: Optional[Callable[[nn.Module], None]] = None, |
||||||
|
device_id: Optional[Union[int, torch.device]] = None, |
||||||
|
sync_module_states: bool = False, |
||||||
|
forward_prefetch: bool = False, |
||||||
|
limit_all_gathers: bool = False, |
||||||
|
use_orig_params: bool = False, |
||||||
|
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None, |
||||||
|
): |
||||||
|
super().__init__() |
||||||
|
self.fsdp_kwargs = dict(process_group=process_group, |
||||||
|
sharding_strategy=sharding_strategy, |
||||||
|
cpu_offload=cpu_offload, |
||||||
|
auto_wrap_policy=auto_wrap_policy, |
||||||
|
backward_prefetch=backward_prefetch, |
||||||
|
mixed_precision=mixed_precision, |
||||||
|
ignored_modules=ignored_modules, |
||||||
|
param_init_fn=param_init_fn, |
||||||
|
device_id=device_id, |
||||||
|
sync_module_states=sync_module_states, |
||||||
|
forward_prefetch=forward_prefetch, |
||||||
|
limit_all_gathers=limit_all_gathers, |
||||||
|
use_orig_params=use_orig_params, |
||||||
|
ignored_parameters=ignored_parameters) |
||||||
|
else: |
||||||
|
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") |
||||||
|
|
||||||
|
def support_no_sync(self) -> bool: |
||||||
|
False |
||||||
|
|
||||||
|
def no_sync(self, model: nn.Module) -> Iterator[None]: |
||||||
|
raise NotImplementedError("Torch fsdp no_sync func not supported yet.") |
||||||
|
|
||||||
|
def control_precision(self) -> bool: |
||||||
|
return True |
||||||
|
|
||||||
|
def supported_precisions(self) -> List[str]: |
||||||
|
return ['fp16', 'bf16'] |
||||||
|
|
||||||
|
def control_device(self) -> bool: |
||||||
|
return True |
||||||
|
|
||||||
|
def supported_devices(self) -> List[str]: |
||||||
|
return ['cuda'] |
||||||
|
|
||||||
|
def configure( |
||||||
|
self, |
||||||
|
model: nn.Module, |
||||||
|
optimizer: Optimizer, |
||||||
|
criterion: Callable = None, |
||||||
|
dataloader: DataLoader = None, |
||||||
|
lr_scheduler: LRScheduler = None, |
||||||
|
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: |
||||||
|
|
||||||
|
model = model.cuda() |
||||||
|
# wrap the model with PyTorch FSDP |
||||||
|
model = TorchFSDPModel(model, **self.fsdp_kwargs) |
||||||
|
|
||||||
|
if not isinstance(optimizer, OptimizerWrapper): |
||||||
|
optimizer = OptimizerWrapper(optimizer) |
||||||
|
|
||||||
|
return model, optimizer, criterion, dataloader, lr_scheduler |
||||||
|
|
||||||
|
def control_checkpoint_io(self) -> bool: |
||||||
|
return True |
||||||
|
|
||||||
|
def get_checkpoint_io(self) -> CheckpointIO: |
||||||
|
return TorchFSDPCheckpointIO() |
@ -0,0 +1,64 @@ |
|||||||
|
from contextlib import nullcontext |
||||||
|
|
||||||
|
import pytest |
||||||
|
import torch |
||||||
|
import torch.distributed as dist |
||||||
|
from packaging import version |
||||||
|
from torch import nn |
||||||
|
from torch.optim import SGD |
||||||
|
|
||||||
|
import colossalai |
||||||
|
from colossalai.booster import Booster |
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse('1.12.0'): |
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
||||||
|
from colossalai.booster.plugin import TorchFSDPPlugin |
||||||
|
|
||||||
|
from colossalai.interface import OptimizerWrapper |
||||||
|
from colossalai.testing import rerun_if_address_is_in_use, spawn |
||||||
|
from tests.kit.model_zoo import model_zoo |
||||||
|
|
||||||
|
|
||||||
|
def run_fn(model_fn, data_gen_fn, output_transform_fn): |
||||||
|
plugin = TorchFSDPPlugin() |
||||||
|
booster = Booster(plugin=plugin) |
||||||
|
model = model_fn() |
||||||
|
optimizer = SGD(model.parameters(), lr=1e-3) |
||||||
|
criterion = lambda x: x.mean() |
||||||
|
data = data_gen_fn() |
||||||
|
|
||||||
|
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} |
||||||
|
|
||||||
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) |
||||||
|
|
||||||
|
assert isinstance(model.module, FSDP) |
||||||
|
assert isinstance(optimizer, OptimizerWrapper) |
||||||
|
|
||||||
|
output = model(**data) |
||||||
|
output = output_transform_fn(output) |
||||||
|
output_key = list(output.keys())[0] |
||||||
|
loss = criterion(output[output_key]) |
||||||
|
|
||||||
|
booster.backward(loss, optimizer) |
||||||
|
optimizer.clip_grad_by_norm(1.0) |
||||||
|
optimizer.step() |
||||||
|
|
||||||
|
|
||||||
|
def check_torch_fsdp_plugin(): |
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): |
||||||
|
if 'diffusers' in name: |
||||||
|
continue |
||||||
|
run_fn(model_fn, data_gen_fn, output_transform_fn) |
||||||
|
torch.cuda.empty_cache() |
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port): |
||||||
|
# init dist env |
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') |
||||||
|
check_torch_fsdp_plugin() |
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") |
||||||
|
@rerun_if_address_is_in_use() |
||||||
|
def test_torch_fsdp_plugin(): |
||||||
|
spawn(run_dist, 2) |
Loading…
Reference in new issue