#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Optional

from torch import Tensor
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively
from colossalai.logging import get_dist_logger


class Engine:
    """Basic engine class for training and evaluation. It runs a specific process
    method :meth:`step`, which is based on the given :attr:`schedule`, over each
    batch of a dataset. It controls one iteration of training.

    :param model: The neural network model
    :type model: ``torch.nn.Module``
    :param optimizer: Optimizer for updating the parameters
    :type optimizer: ``torch.optim.Optimizer``
    :param criterion: Loss function for calculating loss
    :type criterion: ``torch.nn.modules.loss._Loss``, optional
    :param gradient_handlers: A list of gradient handlers used in backward
    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
    :param clip_grad_norm: The norm of gradient clipping
    :type clip_grad_norm: float, optional
    :param ophook_list: A list of operator hooks registered on the model
    :type ophook_list: a list of ``BaseOpHook``, optional
    :param verbose: Whether to display log info
    :type verbose: bool
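
    :Example: a minimal training iteration driven through the engine's public
        API (a sketch: in practice the engine is usually created by
        ``colossalai.initialize`` rather than constructed directly, and
        ``train_dataloader`` below is assumed)::

            engine.train()
            for img, label in train_dataloader:
                engine.zero_grad()
                output = engine(img)
                loss = engine.calc_loss(output, label)
                engine.backward(loss)
                engine.step()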
"""
|
|
|
|
def __init__(self,
|
|
model: Module,
|
|
optimizer: Optimizer,
|
|
criterion: Optional[_Loss] = None,
|
|
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
|
|
clip_grad_norm: float = 0.0,
|
|
ophook_list: Optional[List[BaseOpHook]] = None,
|
|
verbose: bool = True):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._clip_grad_norm = clip_grad_norm
        self._verbose = verbose
        self._logger = get_dist_logger()

        # state
        self.training = True    # default to training mode

        # build gradient handlers
        if gradient_handlers:
            self._gradient_handlers = gradient_handlers
        else:
            self._gradient_handlers = []

        # register operator hooks on the model recursively
        if ophook_list is None:
            self._ophook_list = []
        else:
            self._ophook_list = ophook_list
        register_ophooks_recursively(self._model, self._ophook_list)

    @property
    def model(self):
        """Model attached to the engine"""
        return self._model

    @property
    def optimizer(self):
        """Optimizer attached to the engine"""
        return self._optimizer

    @property
    def criterion(self):
        """Criterion attached to the engine"""
        return self._criterion

    def zero_grad(self):
        """Set the gradients of the parameters to zero"""
        self.optimizer.zero_grad()

    def step(self):
        """Execute the parameter update"""
        self._all_reduce_gradients()
        # clip_grad_norm is provided by ColossalAI's optimizer wrapper;
        # a vanilla ``torch.optim.Optimizer`` does not implement it
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
        return self.optimizer.step()

    def backward(self, loss: Tensor):
        """Start backward propagation given the loss value computed by a loss function

        :param loss: Loss value computed by a loss function
        :type loss: :class:`torch.Tensor`
        """
        ret = self.optimizer.backward(loss)
        for ophook in self._ophook_list:
            ophook.post_iter()
        return ret

    def backward_by_grad(self, tensor, grad):
        """Start backward propagation given the gradient of the output tensor

        :param tensor: Output tensor
        :type tensor: :class:`torch.Tensor`
        :param grad: Gradient passed back to the output
        :type grad: :class:`torch.Tensor`
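
        :Example: a sketch of how a pipeline-parallel schedule might call this
            method (``grad_from_next_stage`` is assumed to be a tensor with
            the same shape as ``output``, received from the next stage)::

                output = engine(micro_batch)
                engine.backward_by_grad(output, grad_from_next_stage)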
        """
        ret = self.optimizer.backward_by_grad(tensor, grad)
        for ophook in self._ophook_list:
            ophook.post_iter()
        return ret

    def calc_loss(self, *args, **kwargs):
        """Compute the loss value

        :param args: Args used in the criterion function
        :param kwargs: Kwargs used in the criterion function

        :return: The loss value
        :rtype: :class:`torch.Tensor`
        """
        return self.criterion(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Run the forward step for the model

        :return: Output of the model
        :rtype: Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`
        """
        return self.model(*args, **kwargs)

    def _all_reduce_gradients(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def train(self):
        """Sets the model to training mode."""
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode."""
        self.training = False
        self._model.eval()
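

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# ``_DemoOptimizer`` is a hypothetical stand-in for the optimizer wrapper that
# ``colossalai.initialize`` normally supplies: the engine calls
# ``optimizer.backward`` and ``optimizer.clip_grad_norm``, which a vanilla
# ``torch.optim.Optimizer`` does not provide. Running this single-process
# still assumes a working ColossalAI install for the imports above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    from torch.nn.utils import clip_grad_norm_

    class _DemoOptimizer(torch.optim.SGD):
        """Hypothetical wrapper adding the two hooks the engine relies on."""

        def backward(self, loss: Tensor):
            loss.backward()

        def clip_grad_norm(self, model: Module, max_norm: float):
            if max_norm > 0.0:
                clip_grad_norm_(model.parameters(), max_norm)

    net = torch.nn.Linear(4, 2)
    engine = Engine(model=net,
                    optimizer=_DemoOptimizer(net.parameters(), lr=0.1),
                    criterion=torch.nn.CrossEntropyLoss(),
                    clip_grad_norm=1.0)

    # one training iteration through the engine's public API
    engine.train()
    engine.zero_grad()
    output = engine(torch.randn(8, 4))
    loss = engine.calc_loss(output, torch.randint(0, 2, (8,)))
    engine.backward(loss)
    engine.step()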