mirror of https://github.com/hpcaitech/ColossalAI
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost
* test: skip unet test if diffusers version > 0.10.2

pull/4002/head
parent c9cff7e7fa
commit 725af3eeeb
@@ -97,10 +97,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.

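With this change, boost can be called without an optimizer, e.g. for an inference- or evaluation-only run. A minimal usage sketch (not part of the commit): it assumes ColossalAI is installed and the distributed environment has already been launched, and the plugin and model below are purely illustrative.

    import torch.nn as nn
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    # Assumption: each process has already initialized its distributed
    # environment (e.g. via colossalai.launch_from_torch) before this point.
    booster = Booster(plugin=TorchDDPPlugin())
    model = nn.Linear(16, 8)

    # Training path (unchanged): pass the optimizer and friends as before.
    # model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    #     model, optimizer, criterion, dataloader, lr_scheduler)

    # Inference path: the optimizer can now be omitted; the remaining return
    # values are just the (None) inputs passed through.
    model, _, _, _, _ = booster.boost(model)
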
@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
 
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer

@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass

@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn

@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                        self.verbose)
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
+                                        self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler

@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                              self.verbose)
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
+                                              self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer

@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass
 
@@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
 
@@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase):
             # wrap the model with PyTorch DDP
             model = TorchDDPModel(model, **self.ddp_kwargs)
 
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler

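The extra optimizer is not None check matters because, with an optional optimizer, None would otherwise pass the old not isinstance(optimizer, OptimizerWrapper) test and get wrapped. A tiny standalone illustration of the guard (the OptimizerWrapper below is a hypothetical stand-in, not the real class):

    class OptimizerWrapper:                     # hypothetical stand-in
        def __init__(self, optim):
            self.optim = optim

    optimizer = None

    # isinstance(None, OptimizerWrapper) is False, so the old check alone
    # would wrap None; the added `optimizer is not None and ...` guard
    # short-circuits and leaves a missing optimizer untouched.
    if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
        optimizer = OptimizerWrapper(optimizer)

    assert optimizer is None
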
@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
-        if not isinstance(optimizer, FSDPOptimizerWrapper):
-            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+            if not isinstance(optimizer, FSDPOptimizerWrapper):
+                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler

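The retained optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) call rebuilds the optimizer over the wrapped module's parameters, since FSDP flattens and replaces the original parameter tensors, so param groups built beforehand would go stale. A standalone sketch of that re-initialization idiom in plain PyTorch (the Sequential wrapper is only a stand-in for FSDP, and the model and hyperparameters are illustrative):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    wrapped = nn.Sequential(model)    # stand-in for an FSDP-style wrapper
    # Re-running __init__ rebuilds param_groups over the wrapper's parameters
    # while reusing the hyperparameters stored in optimizer.defaults.
    optimizer.__init__(wrapped.parameters(), **optimizer.defaults)

    assert optimizer.defaults["lr"] == 0.1
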
@@ -4,12 +4,15 @@ import pytest
 import torch
 
 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False
 
 from test_autochunk_diffuser_utils import run_test
 
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
 
 
+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",

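The new skip gate compares parsed versions rather than raw strings; packaging.version orders multi-digit components numerically, which plain string comparison gets wrong. A short standalone illustration (the version numbers are arbitrary examples):

    from packaging import version

    # Lexicographic string comparison mis-orders multi-digit parts:
    assert "0.9.0" > "0.10.2"            # True as strings, semantically wrong
    # Parsed versions compare numerically, which the skipif condition relies on:
    assert version.parse("0.9.0") < version.parse("0.10.2")
    assert version.parse("0.11.0") > version.parse("0.10.2")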