[booster] make optimizer argument optional for boost (#3993)

* feat: make optimizer optional in Booster.boost

* test: skip unet test if diffusers version > 0.10.2
Wenhao Chen 2023-06-15 17:38:42 +08:00 committed by GitHub
parent c9cff7e7fa
commit 725af3eeeb
9 changed files with 70 additions and 50 deletions


@@ -97,10 +97,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.

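In practice this means `Booster.boost` can now be called without an optimizer, e.g. for inference-only or evaluation runs. A minimal sketch, assuming a distributed environment is already initialized (e.g. via `colossalai.launch_from_torch`) and that the returned values mirror the argument order of the signature above; the linear model is only a placeholder:

import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Placeholder model; any nn.Module works here.
model = nn.Linear(16, 4)

booster = Booster(plugin=TorchDDPPlugin())

# Before this change `optimizer` was a required positional argument;
# now it can simply be omitted and is passed through untouched.
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)
# Expected: nothing was wrapped, so the optimizer slot comes back as None.
assert optimizer is None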

@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion


@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass


@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler

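The same None-guard recurs in the mixed-precision configure above and in the LowLevelZero, TorchDDP, and TorchFSDP plugins below: wrap the optimizer only if one was actually passed and it is not already wrapped. Distilled into a generic sketch (the helper name `maybe_wrap` is hypothetical, not part of the codebase):

from typing import Optional, Type

from torch.optim import Optimizer


def maybe_wrap(optimizer: Optional[Optimizer], wrapper_cls: Type):
    """Wrap `optimizer` only when it exists and is not already wrapped.

    `wrapper_cls` stands in for OptimizerWrapper, GeminiOptimizer, etc.
    """
    if optimizer is not None and not isinstance(optimizer, wrapper_cls):
        optimizer = wrapper_cls(optimizer)
    return optimizer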

@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler


@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass


@@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
@@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase):
         # wrap the model with PyTorch DDP
         model = TorchDDPModel(model, **self.ddp_kwargs)
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
         return model, optimizer, criterion, dataloader, lr_scheduler


@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

-        if not isinstance(optimizer, FSDPOptimizerWrapper):
-            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+            if not isinstance(optimizer, FSDPOptimizerWrapper):
+                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler

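The `optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)` call above is what makes the FSDP path work: wrapping with FSDP replaces the module's parameters with flattened shards, so an optimizer built on the original parameters would step stale tensors. A standalone sketch of the same rebinding trick on plain PyTorch modules (the modules here are placeholders, not the plugin's classes):

import torch
from torch import nn
from torch.optim import SGD

original = nn.Linear(8, 8)
opt = SGD(original.parameters(), lr=0.1, momentum=0.9)

# Stand-in for the FSDP-wrapped model, whose parameters differ from `original`'s.
wrapped = nn.Linear(8, 8)

# Rebind the existing optimizer to the wrapped model's parameters while
# reusing its hyperparameters, mirroring the pattern in the hunk above.
opt.__init__(wrapped.parameters(), **opt.defaults)

assert opt.param_groups[0]["params"][0] is wrapped.weight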

@@ -4,12 +4,15 @@ import pytest
 import torch

 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False

 from test_autochunk_diffuser_utils import run_test
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args


+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
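For reference, the skip logic added in this test can be reproduced in isolation: derive a boolean from the installed package version with `packaging.version` and feed it to `pytest.mark.skipif`. A self-contained sketch (the test body is a placeholder, not the real autochunk test):

import pytest
from packaging import version

try:
    import diffusers
    # Skip when the installed diffusers is newer than the last known-good release.
    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
except ImportError:
    SKIP_UNET_TEST = False  # nothing to skip if diffusers is not installed


@pytest.mark.skipif(SKIP_UNET_TEST, reason="diffusers version > 0.10.2")
def test_unet_placeholder():
    # Placeholder body; the real test lives in the autochunk test suite.
    assert True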