[booster] make optimizer argument optional for boost (#3993)

* feat: make optimizer optional in Booster.boost

* test: skip unet test if diffusers version > 0.10.2
Wenhao Chen 2023-06-15 17:38:42 +08:00 committed by GitHub
parent c9cff7e7fa
commit 725af3eeeb
9 changed files with 70 additions and 50 deletions
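
Usage note (not part of this diff): this change makes the optimizer argument of Booster.boost optional. Below is a minimal, hypothetical sketch of the new call pattern; the Booster construction, device handling, and distributed setup are assumptions, not taken from this commit.

import torch.nn as nn

from colossalai.booster import Booster

# Hypothetical sketch: boost a model without an optimizer (e.g. inference-only).
# Depending on the chosen plugin, the distributed environment may need to be
# initialized first (e.g. via colossalai.launch_from_torch); omitted here.
model = nn.Linear(16, 4)
booster = Booster()  # plugin / mixed_precision left at their defaults

# Previously `optimizer` was a required argument of boost(); now the slots
# that were not supplied simply come back as None.
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)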


@@ -97,10 +97,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.


@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion
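
A hedged sketch of what the FP16TorchMixedPrecision change permits: fp16 boosting of a bare model, with the optimizer branch skipped entirely. The Booster(mixed_precision='fp16') spelling and the CUDA setup below are assumptions, not part of this diff.

import torch
import torch.nn as nn

from colossalai.booster import Booster

# Hypothetical sketch: torch-AMP mixed precision without an optimizer.
model = nn.Linear(16, 4).cuda()
booster = Booster(mixed_precision='fp16')  # assumed to map to FP16TorchMixedPrecision

# Only the model (and criterion, if given) gets wrapped; with this change the
# optimizer wrapping is skipped when no optimizer is passed.
model, *_ = booster.boost(model)

with torch.no_grad():
    out = model(torch.randn(2, 16, device='cuda'))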


@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass


@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)

-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler


@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)

-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler


@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         # implement this method
         pass


@@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         # cast model to cuda
         model = model.cuda()
@@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase):
         # wrap the model with PyTorch DDP
         model = TorchDDPModel(model, **self.ddp_kwargs)

-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)

         return model, optimizer, criterion, dataloader, lr_scheduler


@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)

-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

         if not isinstance(optimizer, FSDPOptimizerWrapper):
             optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)

         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
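
Behavioral note for TorchFSDPPlugin, with a hypothetical usage sketch (not from this commit): when an optimizer is passed, it is re-initialized from optimizer.defaults against the FSDP-wrapped parameters, so it should be built with a single param group. The plugin import path and launch step below are assumptions.

import torch
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

# Hypothetical sketch; assumes the distributed environment is already
# initialized (e.g. torchrun + colossalai.launch_from_torch).
model = nn.Linear(16, 4)

# Use a single param group: multiple groups trigger the warning above, and the
# optimizer is rebuilt from optimizer.defaults over the FSDP-wrapped parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchFSDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)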


@@ -4,12 +4,15 @@ import pytest
 import torch

 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False

 from test_autochunk_diffuser_utils import run_test
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args


+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",