set criterion as optional in colossalai initialize (#336)

pull/394/head
Frank Lee 2022-03-09 11:51:22 +08:00
parent 3213554cc2
commit 6a3188167c
3 changed files with 46 additions and 39 deletions

View File

@@ -3,12 +3,13 @@ from torch.optim import Optimizer
 from torch.nn.modules.loss import _Loss
 from colossalai.context import Config
 from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+from typing import Optional
 def convert_to_torch_amp(model: nn.Module,
                          optimizer: Optimizer,
-                         criterion: _Loss,
-                         amp_config: Config):
+                         criterion: Optional[_Loss] = None,
+                         amp_config: Optional[Config] = None):
     """A helper function to wrap training components with Torch AMP modules
     :param model: your model object
@@ -16,16 +17,18 @@ def convert_to_torch_amp(model: nn.Module,
     :param optimizer: your optimizer object
     :type optimizer: :class:`torch.optim.Optimzer`
     :param criterion: your loss function object
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes
-    :type amp_config: :class:`colossalai.context.Config` or dict
+    :type amp_config: :class:`colossalai.context.Config` or dict, optional
     :return: (model, optimizer, criterion)
     :rtype: Tuple
     """
     model = TorchAMPModel(model)
+    if amp_config is None:
+        amp_config = dict()
     optimizer = TorchAMPOptimizer(optimizer, **amp_config)
-    criterion = TorchAMPLoss(criterion)
+    if criterion:
+        criterion = TorchAMPLoss(criterion)
     return model, optimizer, criterion
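
With this change, convert_to_torch_amp can be called without a loss function. A minimal usage sketch of the relaxed signature follows; the import path, the toy model, and the optimizer are illustrative assumptions, not part of this commit:

# Sketch only: assumes convert_to_torch_amp is importable from
# colossalai.amp.torch_amp and that a CUDA device is available for Torch AMP.
import torch
import torch.nn as nn
from colossalai.amp.torch_amp import convert_to_torch_amp

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# criterion and amp_config may now be omitted; a missing criterion is passed
# through as None instead of being wrapped in TorchAMPLoss.
model, optimizer, criterion = convert_to_torch_amp(model, optimizer)
assert criterion is None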

View File

@@ -9,6 +9,8 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from typing import Optional
+from colossalai.engine.gradient_handler import BaseGradientHandler
 class Engine:
@@ -21,9 +23,9 @@ class Engine:
     :param optimizer: Optimizer for updating the parameters
     :type optimizer: ``torch.optim.Optimizer``
     :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``
+    :type criterion: ``torch.nn.modules.loss._Loss``, optional
     :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: list
+    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
     :param clip_grad_norm: The norm of gradient clipping
     :type clip_grad_norm: float, optional
     :param ophook_list: List of ophook
@@ -31,13 +33,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """
     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 criterion: _Loss,
-                 gradient_handlers: List = None,
+                 criterion: Optional[_Loss] = None,
+                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
-                 ophook_list: List[BaseOpHook] = [],
+                 ophook_list: Optional[List[BaseOpHook]] = None,
                  verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
@@ -47,7 +50,7 @@
         self._logger = get_dist_logger()
         # state
-        self.training = True # default
+        self.training = True  # default
         # build gradient handler
         if gradient_handlers:
@@ -55,7 +58,10 @@
         else:
             self._gradient_handlers = []
-        self._ophook_list = ophook_list
+        if ophook_list is None:
+            self._ophook_list = []
+        else:
+            self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)
     @property
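
To illustrate the new defaults, here is a hedged sketch of constructing an Engine directly with the relaxed signature; the import path and the toy model/optimizer are assumptions, and in practice the Engine is normally built for you by colossalai.initialize:

# Sketch only: assumes Engine is exported from colossalai.engine.
import torch
import torch.nn as nn
from colossalai.engine import Engine

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# criterion, gradient_handlers and ophook_list can all be left out now;
# a None ophook_list is normalized to an empty list inside __init__.
engine = Engine(model=model, optimizer=optimizer)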

View File

@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.ophooks import BaseOpHook
 def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)
-def initialize(model: Union[nn.Module, List[nn.Module]],
-               optimizer: Union[Optimizer, List[Optimizer]],
-               criterion: Union[_Loss, List[_Loss]],
-               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               lr_scheduler: _LRScheduler = None,
-               ophooks: List[BaseOpHook] = [],
-               verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
+               criterion: Optional[_Loss] = None,
+               train_dataloader: Optional[Iterable] = None,
+               test_dataloader: Optional[Iterable] = None,
+               lr_scheduler: Optional[_LRScheduler] = None,
+               ophooks: Optional[List[BaseOpHook]] = None,
+               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     :param optimizer: Your optimizer instance
     :type optimizer: :class:`torch.optim.optimizer.Optimizer`
    :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
     :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
     :param test_dataloader: Dataloader for testing
     :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance
+    :param lr_scheduler: Your lr scheduler instance, optional
     :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
     :param verbose: Whether to print logs
     :type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # gradient accumulation
     grad_accum_size = gpc.config.get('gradient_accumulation', None)
     if grad_accum_size is not None:
-        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
-                                                                                           optimizer=optimizer,
-                                                                                           dataloader=train_dataloader,
-                                                                                           accumulate_size=grad_accum_size,
-                                                                                           gradient_handlers=gradient_handlers,
-                                                                                           lr_scheduler=lr_scheduler)
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+            model=model,
+            optimizer=optimizer,
+            dataloader=train_dataloader,
+            accumulate_size=grad_accum_size,
+            gradient_handlers=gradient_handlers,
+            lr_scheduler=lr_scheduler)
-    engine = Engine(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        gradient_handlers=gradient_handlers,
-        clip_grad_norm=clip_grad_norm,
-        ophook_list=ophooks
-    )
+    engine = Engine(model=model,
+                    optimizer=optimizer,
+                    criterion=criterion,
+                    gradient_handlers=gradient_handlers,
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks)
     return engine, train_dataloader, test_dataloader, lr_scheduler
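
Taken together, colossalai.initialize can now be called without a criterion. A hedged end-to-end sketch follows; the empty config, the launch call, and the toy model/optimizer are assumptions for illustration only:

# Sketch only: assumes a torch.distributed-style launch (RANK, WORLD_SIZE, etc.
# set in the environment) and an empty config; the model and optimizer are
# placeholders.
import colossalai
import torch
import torch.nn as nn

colossalai.launch_from_torch(config=dict())

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# criterion defaults to None now, so setups that compute the loss themselves
# (or only run forward passes) no longer need to pass a dummy loss function.
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer)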