ColossalAI/colossalai/engine/schedule/_base_schedule.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod

import torch

from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device

class BaseSchedule(ABC):
    """A basic helper class to control the process of training or evaluation.
    """

    def __init__(self):
        self.initialized = False
        # data_iter is set by train(); load_batch() checks it before use
        self.data_iter = None
        self.logger = get_global_dist_logger()

    @property
    @abstractmethod
    def num_steps(self):
        """The number of batches in training or evaluation.
        """
        pass
    def initialize(self,
                   dataloader=None,
                   model=None,
                   criterion=None,
                   optimizer=None,
                   lr_scheduler=None):
        """Initializes the schedule and sets parameters before running.

        :param dataloader: DataLoader used in training or evaluation
        :param model: The neural network model
        :param criterion: Criterion for calculating loss
        :param optimizer: Optimizer for updating the parameters
        :param lr_scheduler: Learning rate scheduler used in the process
        """
        self.dataloader = dataloader
        assert model is not None, "Schedule requires a model"
        self.model = model
        assert criterion is not None, "Schedule requires a criterion"
        self.criterion = criterion
        assert optimizer is not None, "Schedule requires an optimizer"
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.initialized = True
    def check_initialized(self):
        """Checks whether the schedule is initialized.
        """
        assert self.initialized, \
            'Schedule is not initialized. Call schedule.initialize(...) before using it.'
    def load_batch(self):
        """Loads a batch from the dataset and returns the data and labels,
        already moved to the same device as the model.

        :return: (data, label)
        :rtype: (Tensor, Tensor)
        """
        self.check_initialized()
        if self.data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        data, label = next(self.data_iter)
        return self._move_to_device(data), self._move_to_device(label)
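    # NOTE: for tuple/list inputs, non-tensor elements are silently dropped;
    # the remaining tensors are detached and moved to the current device.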
    def _move_to_device(self, data):
        if isinstance(data, (tuple, list)):
            data = tuple(
                d.to(get_current_device()).detach() for d in data
                if torch.is_tensor(d)
            )
        elif torch.is_tensor(data):
            data = data.to(get_current_device()).detach()
        return data
    def train(self, dataloader=None, mode=True):
        """Sets the dataloader to be used and puts the model into
        training or evaluation mode.

        :param dataloader: Dataloader to be used
        :param mode: If True, the model is set to training mode; otherwise, evaluation mode
        """
        self.check_initialized()
        if mode:
            self.model.train()
        else:
            self.model.eval()
        if dataloader is not None:
            self.dataloader = dataloader
            self.data_iter = iter(dataloader)
    def zero_grad(self, forward_only=False):
        """Clears the gradients with the optimizer.
        """
        if not forward_only:
            self.check_initialized()
            self.optimizer.zero_grad()
    def get_lr(self):
        """Returns the current learning rate.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_lr()[0]
        else:
            return self.optimizer.param_groups[0]['lr']
    def step(self):
        """Updates the parameters and the learning rate with the optimizer.
        """
        self.check_initialized()
        self.optimizer.step()
        # update the lr scheduler as well, if one is attached
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
    @abstractmethod
    def forward_backward_step(self, forward_only=False, return_loss=True):
        """The process function over a batch of data for training or evaluation.

        :param forward_only: If True, the process won't include backward
        :param return_loss: If False, the loss won't be returned
        """
        pass
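
# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of ColossalAI: a minimal concrete
# schedule showing how the two abstract members above might be implemented
# for plain single-device training. `SimpleSchedule` is a hypothetical name;
# the library's real subclasses live elsewhere in colossalai/engine/schedule.
# ---------------------------------------------------------------------------
class SimpleSchedule(BaseSchedule):

    @property
    def num_steps(self):
        # one step per batch provided by the dataloader
        return len(self.dataloader)

    def forward_backward_step(self, forward_only=False, return_loss=True):
        # fetch a batch already moved to the model's device
        data, label = self.load_batch()
        output = self.model(data)
        # the loss is always computed here because backward() needs it,
        # but it is only handed back when return_loss is True
        loss = self.criterion(output, label)
        if not forward_only:
            loss.backward()
        return output, label, loss if return_loss else None

# A typical driver loop over one epoch might then look like:
#
#     schedule = SimpleSchedule()
#     schedule.initialize(dataloader, model, criterion, optimizer)
#     schedule.train(dataloader, mode=True)
#     for _ in range(schedule.num_steps):
#         schedule.zero_grad()
#         schedule.forward_backward_step()
#         schedule.step()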