#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule

class Engine:
    """Basic engine class for training and evaluation. It runs a specific
    process method, :meth:`step`, which is based on the given :attr:`schedule`,
    over each batch of a dataset.

    :param train_dataloader: Dataloader used in training
    :param test_dataloader: Dataloader used in evaluation
    :param model: The neural network model
    :param criterion: Criterion for calculating loss
    :param optimizer: Optimizer for updating the parameters
    :param lr_scheduler: Learning rate scheduler adjusting the learning rate
        during training or evaluation
    :param schedule: Running schedule in :meth:`step`
    :type train_dataloader: DataLoader, optional
    :type test_dataloader: DataLoader, optional
    :type model: Module
    :type criterion: _Loss, optional
    :type optimizer: Optimizer, optional
    :type lr_scheduler: _LRScheduler, optional
    :type schedule: BaseSchedule, optional
    """
    def __init__(self,
                 train_dataloader: Optional[DataLoader] = None,
                 test_dataloader: Optional[DataLoader] = None,
                 model: Module = None,
                 criterion: _Loss = None,
                 optimizer: Optimizer = None,
                 lr_scheduler: Optional[_LRScheduler] = None,
                 schedule: BaseSchedule = None):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        assert model is not None, "Engine requires a model"
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
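        # Fall back to the single-stage (non-pipeline) schedule when the
        # caller does not provide one.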
        self.schedule = schedule if schedule is not None \
            else NoPipelineSchedule()
        self._logger = get_global_dist_logger()

        # build gradient handler
        self._gradient_handlers = []
        gradient_handler_cfg = []

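        # Gradient handlers are chosen in priority order: an explicit
        # 'gradient_handler' list in the config always wins; otherwise a
        # ZeRO optimizer implies the ZeRO handler, and a data-parallel
        # group larger than one implies the data-parallel handler.
        # An explicit config entry could look like this (illustrative only):
        #     gradient_handler = [dict(type='DataParallelGradientHandler')]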
        if hasattr(gpc.config, 'gradient_handler'):
            assert isinstance(gpc.config.gradient_handler, list), \
                f'argument gradient_handler_cfg expected type list, ' \
                f'but got type {type(gpc.config.gradient_handler)}'
            gradient_handler_cfg = gpc.config.gradient_handler
        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                         ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                "Training with ZeRO is detected, ZeROGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
        if len(gradient_handler_cfg) == 0:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
        for cfg in gradient_handler_cfg:
            handler = build_gradient_handler(cfg, self.model, self.optimizer)
            self._gradient_handlers.append(handler)

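        # Hand the components to the schedule, which drives the
        # forward/backward passes and optimizer updates in step().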
        self.schedule.initialize(self.train_dataloader, self.model,
                                 self.criterion, self.optimizer,
                                 self.lr_scheduler)
        self.forward_only = False

    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def set_dataloader(self, data: DataLoader, train: bool = True):
        """Sets the dataloader for training or evaluation.

        :param data: Dataloader to be set
        :param train: Set the training dataloader if True, otherwise the evaluation dataloader
        :type data: DataLoader
        :type train: bool
        """
        if train:
            self.train_dataloader = data
        else:
            self.test_dataloader = data

    def get_model(self):
        """Returns the neural network model in the engine.
        """
        return self.model

    def get_optimizer(self):
        """Returns the optimizer in the engine.
        """
        return self.optimizer

    def get_lr_scheduler(self):
        """Returns the learning rate scheduler in the engine.
        """
        return self.lr_scheduler

    def train(self):
        """Sets the model to training mode.
        """
        self.forward_only = False
        self.schedule.train(dataloader=self.train_dataloader, mode=True)

    def eval(self):
        """Sets the model to evaluation mode.
        """
        self.forward_only = True
        self.schedule.train(dataloader=self.test_dataloader, mode=False)

    def is_train(self):
        """Returns True if the engine is in training mode, otherwise False.
        """
        return not self.forward_only

    def get_lr(self):
        """Gets the current learning rate.
        """
        return self.schedule.get_lr()

    def step(self, return_loss=True):
        """A running step based on the schedule. Usually, it runs a training or
        evaluation step over a batch of the dataset.

        :param return_loss: loss will be returned if True
        :type return_loss: bool
        :return: (output, label, loss)
        """
        self.schedule.zero_grad(forward_only=self.forward_only)

        output, label, loss = self.schedule.forward_backward_step(
            forward_only=self.forward_only, return_loss=return_loss)

        if not self.forward_only:
            # all-reduce gradients across parallel groups
            self.handle_gradient()

        self.schedule.step()

        return output, label, loss
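

# A minimal usage sketch, assuming `model`, `criterion`, `optimizer`, and
# `train_loader` have been built elsewhere (e.g. by colossalai's
# initialization utilities); these names are placeholders, not part of
# this module:
#
#     engine = Engine(train_dataloader=train_loader,
#                     model=model,
#                     criterion=criterion,
#                     optimizer=optimizer)
#     engine.train()
#     for _ in range(len(train_loader)):
#         output, label, loss = engine.step()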