mirror of https://github.com/hpcaitech/ColossalAI
set criterion as optional in colossalai initialize (#336)
parent 3213554cc2
commit 6a3188167c
torch AMP wrapper (convert_to_torch_amp):

@@ -3,12 +3,13 @@ from torch.optim import Optimizer
 from torch.nn.modules.loss import _Loss
 from colossalai.context import Config
 from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+from typing import Optional


 def convert_to_torch_amp(model: nn.Module,
                          optimizer: Optimizer,
-                         criterion: _Loss,
-                         amp_config: Config):
+                         criterion: Optional[_Loss] = None,
+                         amp_config: Optional[Config] = None):
     """A helper function to wrap training components with Torch AMP modules

     :param model: your model object
@@ -16,16 +17,18 @@ def convert_to_torch_amp(model: nn.Module,
     :param optimizer: your optimizer object
     :type optimizer: :class:`torch.optim.Optimzer`
     :param criterion: your loss function object
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param amp_config: configuration for different amp modes
-    :type amp_config: :class:`colossalai.context.Config` or dict
+    :type amp_config: :class:`colossalai.context.Config` or dict, optional
     :return: (model, optimizer, criterion)
     :rtype: Tuple
     """
     model = TorchAMPModel(model)
+    if amp_config is None:
+        amp_config = dict()
     optimizer = TorchAMPOptimizer(optimizer, **amp_config)
-    criterion = TorchAMPLoss(criterion)
+    if criterion:
+        criterion = TorchAMPLoss(criterion)
     return model, optimizer, criterion
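With both arguments defaulted to None, the AMP helper can now wrap a model/optimizer pair without a loss function or an explicit amp_config. A minimal usage sketch (not part of the commit), assuming the helper is importable as colossalai.amp.torch_amp.convert_to_torch_amp:

import torch.nn as nn
from torch.optim import SGD
from colossalai.amp.torch_amp import convert_to_torch_amp  # assumed import path

model = nn.Linear(16, 4)
optimizer = SGD(model.parameters(), lr=1e-2)

# criterion and amp_config both default to None: the helper falls back to an
# empty amp_config and leaves the absent criterion unwrapped
model, optimizer, criterion = convert_to_torch_amp(model, optimizer)
assert criterion is None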
Engine (colossalai.engine):

@@ -9,6 +9,8 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from typing import Optional
+from colossalai.engine.gradient_handler import BaseGradientHandler


 class Engine:
@@ -21,9 +23,9 @@ class Engine:
     :param optimizer: Optimizer for updating the parameters
     :type optimizer: ``torch.optim.Optimizer``
     :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``
+    :type criterion: ``torch.nn.modules.loss._Loss``, optional
     :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: list
+    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
     :param clip_grad_norm: The norm of gradient clipping
     :type clip_grad_norm: float, optional
     :param ophook_list: List of ophook
@@ -31,13 +33,14 @@ class Engine:
     :param verbose: whether to display log info
     :type verbose: bool
     """

     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
-                 criterion: _Loss,
-                 gradient_handlers: List = None,
+                 criterion: Optional[_Loss] = None,
+                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
-                 ophook_list: List[BaseOpHook] = [],
+                 ophook_list: Optional[List[BaseOpHook]] = None,
                  verbose: bool = True):
         self._model = model
         self._optimizer = optimizer
@@ -47,7 +50,7 @@ class Engine:
         self._logger = get_dist_logger()

         # state
-        self.training = True # default
+        self.training = True  # default

         # build gradient handler
         if gradient_handlers:
@@ -55,7 +58,10 @@ class Engine:
         else:
             self._gradient_handlers = []

-        self._ophook_list = ophook_list
+        if ophook_list is None:
+            self._ophook_list = []
+        else:
+            self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)

     @property
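Besides making the argument optional, switching ophook_list from a default of [] to None also removes the classic mutable-default-argument pitfall: a default list is created once at function definition time and then shared by every call. A generic, self-contained illustration of the difference (not ColossalAI code; the Engine constructor now uses the None-check pattern shown in register_fresh):

def register_shared(hook, hooks=[]):
    # one list object shared across all calls
    hooks.append(hook)
    return hooks

def register_fresh(hook, hooks=None):
    # a fresh list per call, as Engine.__init__ now does for ophook_list
    if hooks is None:
        hooks = []
    hooks.append(hook)
    return hooks

print(register_shared("a"), register_shared("b"))  # ['a', 'b'] ['a', 'b'] -- state leaks between calls
print(register_fresh("a"), register_fresh("b"))    # ['a'] ['b']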
colossalai.initialize:

@@ -27,7 +27,7 @@ from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.ophooks import BaseOpHook


 def get_default_parser():
@@ -216,15 +216,14 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
           verbose=verbose)


-def initialize(model: Union[nn.Module, List[nn.Module]],
-               optimizer: Union[Optimizer, List[Optimizer]],
-               criterion: Union[_Loss, List[_Loss]],
-               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
-               lr_scheduler: _LRScheduler = None,
-               ophooks: List[BaseOpHook] = [],
-               verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
+               criterion: Optional[_Loss] = None,
+               train_dataloader: Optional[Iterable] = None,
+               test_dataloader: Optional[Iterable] = None,
+               lr_scheduler: Optional[_LRScheduler] = None,
+               ophooks: Optional[List[BaseOpHook]] = None,
+               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

@@ -233,12 +232,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     :param optimizer: Your optimizer instance
     :type optimizer: :class:`torch.optim.optimizer.Optimizer`
     :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`
+    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
     :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
     :param test_dataloader: Dataloader for testing
     :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance
+    :param lr_scheduler: Your lr scheduler instance, optional
     :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
     :param verbose: Whether to print logs
     :type verbose: bool, optional
@@ -399,20 +398,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # gradient accumulation
     grad_accum_size = gpc.config.get('gradient_accumulation', None)
     if grad_accum_size is not None:
-        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
-                                                                                           optimizer=optimizer,
-                                                                                           dataloader=train_dataloader,
-                                                                                           accumulate_size=grad_accum_size,
-                                                                                           gradient_handlers=gradient_handlers,
-                                                                                           lr_scheduler=lr_scheduler)
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+            model=model,
+            optimizer=optimizer,
+            dataloader=train_dataloader,
+            accumulate_size=grad_accum_size,
+            gradient_handlers=gradient_handlers,
+            lr_scheduler=lr_scheduler)

-    engine = Engine(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        gradient_handlers=gradient_handlers,
-        clip_grad_norm=clip_grad_norm,
-        ophook_list=ophooks
-    )
+    engine = Engine(model=model,
+                    optimizer=optimizer,
+                    criterion=criterion,
+                    gradient_handlers=gradient_handlers,
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks)

     return engine, train_dataloader, test_dataloader, lr_scheduler
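After this change, colossalai.initialize can be called with only a model and an optimizer; criterion, dataloaders, lr_scheduler and ophooks are all optional. A hedged usage sketch (not part of the commit), assuming the script is launched through torch.distributed so that launch_from_torch can read the process group settings, and using an empty placeholder config; since no criterion is attached, the loss is computed manually before backward/step:

import colossalai
import torch
import torch.nn as nn
from torch.optim import Adam

colossalai.launch_from_torch(config=dict())     # empty placeholder config

model = nn.Linear(32, 8)                        # placeholder model
optimizer = Adam(model.parameters(), lr=1e-3)

# no criterion, dataloaders, scheduler or ophooks passed in
engine, *_ = colossalai.initialize(model=model, optimizer=optimizer)

# initialize() moves the model to the current device, so build tensors there too
device = next(model.parameters()).device
data = torch.randn(4, 32, device=device)
target = torch.randn(4, 8, device=device)

engine.train()
engine.zero_grad()
loss = nn.functional.mse_loss(engine(data), target)  # loss computed by hand, no engine.criterion
engine.backward(loss)
engine.step()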