mirror of https://github.com/hpcaitech/ColossalAI
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost
* test: skip unet test if diffusers version > 0.10.2
parent: c9cff7e7fa
commit: 725af3eeeb
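With this change, `Booster.boost` no longer requires an optimizer, which is convenient when boosting a model for inference only. A minimal usage sketch follows; the plugin choice, the toy model, and the assumption that the distributed environment has already been launched are illustrative, not part of this commit.

# Hypothetical usage sketch: boost a model without an optimizer (e.g. inference only).
# Assumes the distributed environment was already initialized, e.g. via
# colossalai.launch_from_torch(config={}); the plugin and model below are placeholders.
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Linear(16, 8)
booster = Booster(plugin=TorchDDPPlugin())

# All arguments other than `model` are now optional; omitted ones are passed through
# (here, `optimizer` is expected to come back as None).
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)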
@@ -97,10 +97,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
 
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass
@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass
 
@@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
 
@@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase):
             # wrap the model with PyTorch DDP
             model = TorchDDPModel(model, **self.ddp_kwargs)
 
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
         if not isinstance(optimizer, FSDPOptimizerWrapper):
             optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
 
@@ -4,12 +4,15 @@ import pytest
 import torch
 
 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False
 
 from test_autochunk_diffuser_utils import run_test
 
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
 
 
+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
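For reference, the version-gating pattern introduced in the test above can be written as a standalone sketch. The test name and body below are placeholders, not the actual test touched by this commit.

# Minimal sketch of the version-gated skip used above (test name is hypothetical).
import pytest
from packaging import version

try:
    import diffusers
    # skip when the installed diffusers release is newer than 0.10.2
    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
except ImportError:
    SKIP_UNET_TEST = False  # nothing to skip when diffusers is not installed


@pytest.mark.skipif(SKIP_UNET_TEST, reason="diffusers version > 0.10.2")
def test_unet_autochunk():
    ...  # placeholder body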