#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule


class Engine:
    """Basic engine class for training and evaluation. It runs the process method
    :meth:`step`, which executes the given :attr:`schedule` over each batch of a dataset.

    :param train_dataloader: Dataloader used in training
    :param test_dataloader: Dataloader used in evaluation
    :param model: The neural network model
    :param criterion: Criterion for calculating the loss
    :param optimizer: Optimizer for updating the parameters
    :param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation
    :param schedule: Running schedule used in :meth:`step`
    :param gradient_accumulation: Number of :meth:`step` calls over which gradients are accumulated before an optimizer update
    :param lr_scheduler_step: Whether the scheduler is stepped per 'epoch' (in :meth:`complete`) or per optimizer update (in :meth:`step`)
    :type train_dataloader: DataLoader, optional
    :type test_dataloader: DataLoader, optional
    :type model: Module
    :type criterion: _Loss, optional
    :type optimizer: Optimizer, optional
    :type lr_scheduler: _LRScheduler, optional
    :type schedule: BaseSchedule, optional
    :type gradient_accumulation: int, optional
    :type lr_scheduler_step: str, optional
    """

    def __init__(self,
                 train_dataloader: Optional[DataLoader] = None,
                 test_dataloader: Optional[DataLoader] = None,
                 model: Module = None,
                 criterion: _Loss = None,
                 optimizer: Optimizer = None,
                 lr_scheduler: Optional[_LRScheduler] = None,
                 schedule: BaseSchedule = None,
                 gradient_accumulation: int = 1,
                 lr_scheduler_step: str = 'epoch'):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        assert model is not None, "Engine requires a model"
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.schedule = schedule if schedule is not None \
            else NoPipelineSchedule()
        self.grad_accum_size = gradient_accumulation
        self.grad_accum_step = 0
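        # lr_step selects when the LR scheduler is stepped: 0 steps it once per
        # epoch (in complete()), 1 steps it after every optimizer update (in step()).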
        self.lr_step = 0  # for epoch updating
        if lr_scheduler_step != 'epoch':
            self.lr_step = 1
        self._logger = get_global_dist_logger()

        # build gradient handler
        self._gradient_handlers = []
        gradient_handler_cfg = []

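        # Pick gradient handlers in order of priority: an explicit list in
        # gpc.config.gradient_handler, otherwise a ZeRO handler when a ZeRO
        # optimizer is in use, otherwise a data-parallel handler when the data
        # parallel group has more than one rank.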
        if hasattr(gpc.config, 'gradient_handler'):
            assert isinstance(gpc.config.gradient_handler, list), \
                f'argument gradient_handler_cfg expected type list, ' \
                f'but got type {type(gpc.config.gradient_handler)}'
            gradient_handler_cfg = gpc.config.gradient_handler
        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                         ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                "Training with ZeRO is detected, ZeROGradientHandler is automatically "
                "added even though it is not specified in the configuration",
                ranks=[0])
        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though it is not specified in the configuration",
                ranks=[0])
        if len(gradient_handler_cfg) == 0:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
        for cfg in gradient_handler_cfg:
            handler = build_gradient_handler(cfg, self.model, self.optimizer)
            self._gradient_handlers.append(handler)

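        # Hand the dataloader, model, criterion and optimizer over to the schedule,
        # which runs the actual forward/backward passes, and tell it how many
        # micro-steps to accumulate gradients over.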
        self.schedule.initialize(self.train_dataloader, self.model,
                                 self.criterion, self.optimizer)
        self.schedule.grad_accum = self.grad_accum_size
        # add for robustness
        if self.optimizer is None:
            self.forward_only = True
        else:
            self.forward_only = False

    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def set_dataloader(self, data: DataLoader, train: bool = True):
        """Sets the dataloader used in training or evaluation.

        :param data: Dataloader to be set
        :param train: Set the training dataloader if True, otherwise the evaluation dataloader
        :type data: DataLoader
        :type train: bool
        """
        if train:
            self.train_dataloader = data
        else:
            self.test_dataloader = data

    def get_model(self):
        """Returns the neural network model in the engine.
        """
        return self.model

    def get_optimizer(self):
        """Returns the optimizer in the engine.
        """
        return self.optimizer

    def get_lr_scheduler(self):
        """Returns the learning rate scheduler in the engine.
        """
        return self.lr_scheduler

    def train(self):
        """Sets the engine and model to training mode.
        """
        self.forward_only = False
        self.schedule.train(dataloader=self.train_dataloader, mode=True)

    def eval(self):
        """Sets the engine and model to evaluation mode.
        """
        self.forward_only = True
        self.schedule.train(dataloader=self.test_dataloader, mode=False)

    def is_train(self):
        """Returns True if the engine is in training mode, otherwise False.
        """
        return not self.forward_only

    def get_lr(self):
        """Gets the current learning rate.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_lr()[0]
        else:
            return self.optimizer.param_groups[0]['lr']

    def step(self, return_loss=True):
        """A running step based on the schedule. Usually, it runs a training or
        evaluation step over a batch of the dataset.

        :param return_loss: loss will be returned if True
        :type return_loss: bool
        :return: (output, label, loss)
        """
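        # Gradients are zeroed only at the start of an accumulation cycle, so the
        # backward passes of grad_accum_size consecutive calls accumulate before
        # the all-reduce and optimizer update below.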
        if not self.forward_only and self.grad_accum_step == 0:
            self.schedule.zero_grad()

        output, label, loss = self.schedule.forward_backward_step(
            forward_only=self.forward_only, return_loss=return_loss)

        if not self.forward_only:
            self.grad_accum_step += 1
            if self.grad_accum_step == self.grad_accum_size:
                # all reduce gradients
                self.handle_gradient()
                self.schedule.step()
                if self.lr_scheduler is not None and self.lr_step:
                    self.lr_scheduler.step()
                self.grad_accum_step = 0

        return output, label, loss

    def complete(self):
        """Updates the engine after an epoch.
        """
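        # End-of-epoch bookkeeping: let the schedule consume any remaining batch
        # and step the LR scheduler when it is configured for epoch-wise updates.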
        self.schedule.consume_batch()
        if self.lr_scheduler is not None and self.lr_step == 0:
            self.lr_scheduler.step()