ColossalAI/colossalai/engine/_base_engine.py

177 lines
6.3 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from .schedule import BaseSchedule
class Engine:
    """Basic engine class for training and evaluation.

    It runs :meth:`step` over each batch of a dataset, delegating the
    forward/backward and optimizer work to the given schedule. One call to
    :meth:`step` performs one training (or evaluation) iteration.

    :param model: The neural network model
    :param optimizer: Optimizer for updating the parameters
    :param criterion: Loss function handed to the schedule's forward/backward step
    :param step_schedule: Schedule that runs the forward/backward pass and the
        optimizer step inside :meth:`step`
    :param gradient_handlers: Gradient-handler configs, each passed to
        ``build_gradient_handler`` (e.g. ``dict(type='DataParallelGradientHandler')``).
        If ``None``, a ZeRO or data-parallel handler is added automatically
        when applicable; otherwise no handler is used and a warning is logged
    :param gradient_accumulation: Number of forward/backward passes per
        optimizer step; must be larger than 0
    :param gradient_clipping: Value forwarded to the schedule's
        ``optimizer_step`` as the gradient-clipping setting (how e.g. 0.0 is
        interpreted is up to the schedule)
    :type model: Module
    :type optimizer: Optimizer
    :type criterion: _Loss
    :type step_schedule: BaseSchedule
    :type gradient_handlers: list, optional
    :type gradient_accumulation: int, optional
    :type gradient_clipping: float, optional
    """

    def __init__(self,
                 model: Module,
                 optimizer: Optimizer,
                 criterion: _Loss,
                 step_schedule: 'BaseSchedule',
                 gradient_handlers: list = None,
                 gradient_accumulation: int = 1,
                 gradient_clipping: float = 0.0,
                 ):
        # Validate arguments up front so a bad configuration fails before the
        # schedule is initialized or the logger is fetched.
        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
        if gradient_handlers is not None:
            assert isinstance(gradient_handlers, list), \
                'argument gradient_handlers expected type list, ' \
                f'but got type {type(gradient_handlers)}'

        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._schedule = step_schedule
        # let the schedule prepare itself for this model/optimizer pair
        self._schedule.initialize(model, optimizer)

        # state
        self.training = True  # default

        # gradient accumulation / clipping settings
        self._grad_accum_size = gradient_accumulation
        self._grad_clip = gradient_clipping
        self._logger = get_global_dist_logger()

        # Build gradient handlers; when none are configured explicitly, try to
        # pick one automatically from the optimizer / parallel setup.
        if gradient_handlers is None:
            gradient_handlers = self._default_gradient_handlers(optimizer)

        if gradient_handlers is None:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
            self._gradient_handlers = []
        else:
            self._gradient_handlers = [
                build_gradient_handler(cfg, model, optimizer)
                for cfg in gradient_handlers
            ]

    def _default_gradient_handlers(self, optimizer):
        """Return automatically chosen handler configs, or ``None`` if none apply.

        A ZeRO (level 2/3) optimizer selects ``ZeROGradientHandler``; otherwise
        a data-parallel world size above 1 selects ``DataParallelGradientHandler``.
        Each automatic choice is logged on rank 0.
        """
        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                  ZeroRedundancyOptimizer_Level_3)):
            self._logger.info(
                "Training with zero is detected, ZeROGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
            return [dict(type='ZeROGradientHandler')]
        if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
            return [dict(type='DataParallelGradientHandler')]
        return None

    @property
    def model(self):
        """The wrapped model."""
        return self._model

    @property
    def optimizer(self):
        """The optimizer used for parameter updates."""
        return self._optimizer

    @property
    def criterion(self):
        """The loss function handed to the schedule."""
        return self._criterion

    @property
    def schedule(self):
        """The schedule driving :meth:`step`."""
        return self._schedule

    @property
    def gradient_accumulation(self):
        """Number of forward/backward passes per optimizer step."""
        return self._grad_accum_size

    def handle_gradient(self):
        """Call ``handle_gradient()`` on every registered gradient handler.

        These handlers perform e.g. all-reduce of gradients across parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def train(self):
        """Sets the model to training mode."""
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode."""
        self.training = False
        self._model.eval()

    def step(self,
             data_iter,
             is_last_iteration: bool = False,
             return_loss=True):
        """Run one iteration according to the schedule.

        In training mode the gradients are zeroed, ``gradient_accumulation``
        forward/backward passes are run through the schedule (which is handed
        ``data_iter`` on each pass), the gradient handlers are applied once, and
        a single optimizer step is taken. In evaluation mode a single
        forward-only pass is run.

        :param data_iter: Data iterator of the dataset
        :param is_last_iteration: If True, this iteration is the last one in
            the epoch and any batches left in ``data_iter`` are drained afterwards
        :param return_loss: Forwarded to the schedule; loss will be returned if True
        :type data_iter: Iterator
        :type is_last_iteration: bool, optional
        :type return_loss: bool, optional
        :return: ``(output, label, loss)`` from the last forward/backward pass
        """
        if self.training:
            self._optimizer.zero_grad()
            # Gradients accumulate across the passes; only the results of the
            # final pass are returned.
            for _ in range(self._grad_accum_size):
                output, label, loss = self._schedule.forward_backward_step(
                    data_iter, self._model, self._criterion, self._optimizer,
                    forward_only=False,
                    grad_accum_size=self._grad_accum_size,
                    return_loss=return_loss)
            # all-reduce the accumulated gradients, then update once
            self.handle_gradient()
            self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
        else:
            output, label, loss = self._schedule.forward_backward_step(
                data_iter, self._model, self._criterion, self._optimizer,
                forward_only=True,
                grad_accum_size=1,
                return_loss=return_loss)

        # consume the remaining dataset left out due to gradient accumulation
        if is_last_iteration:
            for _ in data_iter:
                pass
        return output, label, loss