ColossalAI/colossalai/engine/_base_engine.py

215 lines
7.6 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Iterable
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.logging import get_dist_logger
class Engine:
    """Basic engine class for training and evaluation.

    The engine bundles a model, an optimizer, an optional criterion, gradient handlers,
    op hooks and a schedule. It exposes the pieces of one training iteration
    (:meth:`zero_grad`, forward via :meth:`__call__`, :meth:`backward`, :meth:`step`) and can
    delegate a whole forward/backward step to the given :attr:`schedule` through
    :meth:`execute_schedule`.

    Args:
        model (``torch.nn.Module``): The neural network model.
        optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
        criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
        gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
        clip_grad_norm (float, optional): The norm of gradient clipping.
        ophook_list (list): List of ophook.
        verbose (bool): whether to display log info.
        schedule (``BaseSchedule``): Runtime schedule. Defaults to a ``NonPipelineSchedule``.

    Examples:
        >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
        >>> model = ...
        >>> criterion = ...
        >>> optimizer = ...
        >>> train_dataloader = ...
        >>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion)
        >>> engine.train()
        >>> for inputs, labels in train_dataloader:
        >>>     # set gradients to zero
        >>>     engine.zero_grad()
        >>>     # run forward pass
        >>>     outputs = engine(inputs)
        >>>     # compute loss value and run backward pass
        >>>     loss = engine.criterion(outputs, labels)
        >>>     engine.backward(loss)
        >>>     # update parameters
        >>>     engine.step()

    An example of using Engine in training can be found in
    `Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_. and
    `Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_.
    """

    def __init__(self,
                 model: Module,
                 optimizer: "ColossalaiOptimizer",
                 criterion: Optional[_Loss] = None,
                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                 clip_grad_norm: float = 0.0,
                 ophook_list: Optional[List[BaseOpHook]] = None,
                 verbose: bool = True,
                 schedule: Optional[BaseSchedule] = None):
        """Store the training components, pick a schedule and register op hooks.

        See the class docstring for the meaning of each argument.

        Raises:
            AssertionError: If a truthy ``schedule`` is given that is not a ``BaseSchedule``.
        """
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._clip_grad_norm = clip_grad_norm
        self._verbose = verbose
        self._logger = get_dist_logger()

        # state
        self.training = True    # default

        # build gradient handler: None or an empty list both yield a fresh empty list
        self._gradient_handlers = gradient_handlers if gradient_handlers else []

        # The caller's list is kept by reference, so add_hook() appends to that same list.
        self._ophook_list = [] if ophook_list is None else ophook_list

        # build schedule
        if schedule:
            assert isinstance(schedule, BaseSchedule), \
                f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
            self._schedule = schedule
        else:
            self._schedule = NonPipelineSchedule()

        # pipeline schedules get a chance to prepare themselves against this engine
        if self.uses_pipeline:
            self._schedule.pre_processing(self)

        # register hook if any
        if len(self._ophook_list) > 0:
            register_ophooks_recursively(self._model, self._ophook_list)

    @property
    def ophooks(self):
        """show current activated ophooks"""
        return self._ophook_list

    @property
    def model(self):
        """Model attached to the engine"""
        return self._model

    @property
    def optimizer(self):
        """Optimizer attached to the engine"""
        return self._optimizer

    @property
    def criterion(self):
        """Criterion attached to the engine"""
        return self._criterion

    @property
    def schedule(self):
        """Schedule attached to the engine"""
        return self._schedule

    @property
    def uses_pipeline(self):
        """show the pipeline parallel used or not

        True when the schedule is a ``PipelineSchedule`` or an ``InterleavedPipelineSchedule``.
        """
        return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))

    def add_hook(self, ophook: BaseOpHook) -> None:
        """Add an op hook instance to the engine and register the hook list on the model.

        A warning is logged (once) when a hook of exactly the same type is already present;
        the new hook is added regardless.

        Args:
            ophook (``BaseOpHook``): The hook instance to add.
        """
        # whether this hook exist
        if any(type(existing) is type(ophook) for existing in self._ophook_list):
            logger = get_dist_logger()
            logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}")
        self._ophook_list.append(ophook)
        # NOTE(review): the whole list is passed again on every call; whether earlier hooks
        # end up registered twice depends on register_ophooks_recursively - confirm.
        register_ophooks_recursively(self._model, self._ophook_list)

    def remove_hook(self, ophook: BaseOpHook) -> None:
        """Placeholder for hook removal: only logs a warning, nothing is removed.

        Args:
            ophook (``BaseOpHook``): The hook instance that was requested to be removed (unused).
        """
        logger = get_dist_logger()
        logger.warning("removing hooks is currently not supported")

    def zero_grad(self) -> None:
        """Set the gradient of parameters to zero
        """
        self.optimizer.zero_grad()

    def step(self):
        """Execute parameter update

        Runs every gradient handler, asks the optimizer to clip gradients with the
        configured ``clip_grad_norm``, then performs the optimizer step.

        Returns:
            Whatever ``optimizer.step()`` returns.
        """
        self._all_reduce_gradients()
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
        return self.optimizer.step()

    def backward(self, loss: Tensor):
        """Start backward propagation given the loss value computed by a loss function.

        After the optimizer's backward pass, ``post_iter()`` is called on every op hook.

        Args:
            loss (:class:`torch.Tensor`): Loss value computed by a loss function.

        Returns:
            Whatever ``optimizer.backward(loss)`` returns.
        """
        ret = self.optimizer.backward(loss)
        for ophook in self._ophook_list:
            ophook.post_iter()
        return ret

    def backward_by_grad(self, tensor, grad):
        """Start backward propagation given the gradient of the output tensor.

        After the optimizer's backward pass, ``post_iter()`` is called on every op hook.

        Args:
            tensor (:class:`torch.Tensor`): Output tensor.
            grad (:class:`torch.Tensor`): Gradient passed back to the output.

        Returns:
            Whatever ``optimizer.backward_by_grad(tensor, grad)`` returns.
        """
        ret = self.optimizer.backward_by_grad(tensor, grad)
        for ophook in self._ophook_list:
            ophook.post_iter()
        return ret

    def __call__(self, *args, **kwargs):
        """Run the forward step for the model.

        All positional and keyword arguments are forwarded to the model unchanged.

        Returns:
            Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`: Output of the model.
        """
        return self.model(*args, **kwargs)

    def _all_reduce_gradients(self) -> None:
        """Handles all-reduce operations of gradients across different parallel groups.

        Calls ``handle_gradient()`` on each configured gradient handler, in list order.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def execute_schedule(self, data_iter: Iterable, **kwargs):
        """Run the forward, loss computation, and backward for the model.

        Delegates to ``schedule.forward_backward_step(self, data_iter, **kwargs)``.

        Args:
            data_iter (Iterable): Data iterator handed to the schedule.
            **kwargs: Extra keyword arguments forwarded to the schedule.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
        """
        # Unpacking enforces that the schedule returned exactly three values.
        output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
        return output, label, loss

    def train(self) -> None:
        """Sets the model to training mode.
        """
        self.training = True
        self._model.train()

    def eval(self) -> None:
        """Sets the model to evaluation mode.
        """
        self.training = False
        self._model.eval()