#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from .schedule import BaseSchedule


class Engine:
    """Basic engine class for training and evaluation. It runs the given
    :attr:`schedule` over each batch of a dataset via :meth:`step` and controls
    one iteration of training.

    :param model: The neural network model
    :param optimizer: Optimizer for updating the parameters
    :param criterion: Loss function for computing the loss
    :param step_schedule: Running schedule used in :meth:`step`
    :param gradient_handlers: List of gradient handler configurations
    :param gradient_accumulation: Steps of gradient accumulation
    :param gradient_clipping: The norm of gradient clipping
    :type model: Module
    :type optimizer: Optimizer
    :type criterion: _Loss
    :type step_schedule: BaseSchedule
    :type gradient_handlers: list, optional
    :type gradient_accumulation: int, optional
    :type gradient_clipping: float, optional
    """

    def __init__(self,
                 model: Module,
                 optimizer: Optimizer,
                 criterion: _Loss,
                 step_schedule: BaseSchedule,
                 gradient_handlers: list = None,
                 gradient_accumulation: int = 1,
                 gradient_clipping: float = 0.0,
                 ):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._schedule = step_schedule

        # initialize the schedule with the model and optimizer
        self._schedule.initialize(model, optimizer)

        # state
        self.training = True  # default

        # gradient accumulation
        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
        self._grad_accum_size = gradient_accumulation
        self._grad_clip = gradient_clipping
        self._logger = get_global_dist_logger()

        # build gradient handlers
        self._gradient_handlers = []

        if gradient_handlers is not None:
            assert isinstance(gradient_handlers, list), \
                f'argument gradient_handlers expected type list, ' \
                f'but got type {type(gradient_handlers)}'
        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                    ZeroRedundancyOptimizer_Level_3)):
            gradient_handlers = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                "Training with ZeRO is detected, ZeROGradientHandler is automatically "
                "added even though it is not specified in the configuration",
                ranks=[0])
        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            gradient_handlers = [dict(type='DataParallelGradientHandler')]
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though it is not specified in the configuration",
                ranks=[0])

        if gradient_handlers is None:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
        else:
            for cfg in gradient_handlers:
                handler = build_gradient_handler(cfg, model, optimizer)
                self._gradient_handlers.append(handler)

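    # A minimal sketch of passing gradient handler configs explicitly instead of
    # relying on the automatic selection above; the dict format mirrors the
    # configs built in __init__, and `model`, `optimizer`, `criterion` and
    # `schedule` stand for user-provided objects:
    #
    #   engine = Engine(model, optimizer, criterion, schedule,
    #                   gradient_handlers=[dict(type='DataParallelGradientHandler')])
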
    @property
    def model(self):
        return self._model

    @property
    def optimizer(self):
        return self._optimizer

    @property
    def criterion(self):
        return self._criterion

    @property
    def schedule(self):
        return self._schedule

    @property
    def gradient_accumulation(self):
        return self._grad_accum_size

    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

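    # Conceptually, each handler's handle_gradient() all-reduces parameter
    # gradients over its parallel group. A minimal data-parallel sketch,
    # assuming torch.distributed is already initialized, would be:
    #
    #   import torch.distributed as dist
    #   for param in model.parameters():
    #       if param.grad is not None:
    #           dist.all_reduce(param.grad)
    #           param.grad.div_(dist.get_world_size())
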
    def train(self):
        """Sets the model to training mode.
        """
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode.
        """
        self.training = False
        self._model.eval()

    def step(self,
             data_iter,
             is_last_iteration: bool = False,
             return_loss=True):
        """A running step based on the schedule. Usually, it runs a training or
        evaluation step over a batch of the dataset.

        :param data_iter: Data iterator of the dataset
        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
        :param return_loss: Loss will be returned if True
        :type data_iter: Iterator
        :type is_last_iteration: bool, optional
        :type return_loss: bool, optional
        :return: (output, label, loss)
        """
        if self.training:
            self._optimizer.zero_grad()

        # differentiate training and eval with gradient accumulation
        if self.training:
            for i in range(self._grad_accum_size):
                output, label, loss = self._schedule.forward_backward_step(
                    data_iter, self._model, self._criterion, self._optimizer,
                    forward_only=False,
                    grad_accum_size=self._grad_accum_size,
                    return_loss=return_loss)

                if i == self._grad_accum_size - 1:
                    # all-reduce gradients before the optimizer update
                    self.handle_gradient()
                    self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
        else:
            output, label, loss = self._schedule.forward_backward_step(
                data_iter, self._model, self._criterion, self._optimizer,
                forward_only=True,
                grad_accum_size=1,
                return_loss=return_loss)

        # consume the remaining samples left over due to gradient accumulation
        if is_last_iteration:
            while True:
                try:
                    _ = next(data_iter)
                except StopIteration:
                    break

        return output, label, loss
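
# A minimal usage sketch, assuming an already-initialized Colossal-AI context
# and user-provided `model`, `optimizer`, `criterion`, `schedule` and
# `train_dataloader` objects:
#
#   engine = Engine(model=model,
#                   optimizer=optimizer,
#                   criterion=criterion,
#                   step_schedule=schedule,
#                   gradient_accumulation=4)
#   engine.train()
#   data_iter = iter(train_dataloader)
#   output, label, loss = engine.step(data_iter)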