#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod

import torch

from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device


class BaseSchedule(ABC):
    """A basic helper class to control the process of training or evaluation.
    """

    def __init__(self):
        self.initialized = False
        self.logger = get_global_dist_logger()
        # data_iter is attached by train(); default to None so load_batch()
        # can raise a clear error when no dataloader has been set yet
        self.data_iter = None
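
    # Typical call order for a concrete schedule (illustrative, not enforced
    # by this class): initialize(...) once before running, train(dataloader,
    # mode) to attach a dataloader and toggle train/eval mode, then per
    # iteration zero_grad() -> forward_backward_step(...) -> step().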

    @property
    @abstractmethod
    def num_steps(self):
        """The number of batches in training or evaluation.
        """
        pass

    def initialize(self,
                   dataloader=None,
                   model=None,
                   criterion=None,
                   optimizer=None,
                   lr_scheduler=None):
        """Initializes the schedule and sets parameters before running.

        :param dataloader: DataLoader used in training or evaluation
        :param model: The neural network model
        :param criterion: Criterion for calculating loss
        :param optimizer: Optimizer for updating the parameters
        :param lr_scheduler: Learning rate scheduler in the process
        """
        self.dataloader = dataloader
        assert model is not None, "Schedule requires a model"
        self.model = model
        assert criterion is not None, "Schedule requires a criterion"
        self.criterion = criterion
        assert optimizer is not None, "Schedule requires an optimizer"
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.initialized = True

    def check_initialized(self):
        """Checks whether the schedule is initialized.
        """
        assert self.initialized, \
            'Schedule is not initialized. Call schedule.initialize(...) before using it.'

    def load_batch(self):
        """Loads a batch from the dataloader and returns the data and labels,
        already moved to the same device as the model.

        :return: (data, label)
        :rtype: (Tensor, Tensor)
        """
        self.check_initialized()
        if self.data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        data, label = next(self.data_iter)
        return self._move_to_device(data), self._move_to_device(label)

    def _move_to_device(self, data):
        # Move a tensor, or every tensor in a tuple/list, to the current
        # device and detach it from the autograd graph
        if isinstance(data, (tuple, list)):
            data = tuple([
                d.to(get_current_device()).detach() for d in data
                if torch.is_tensor(d)
            ])
        elif torch.is_tensor(data):
            data = data.to(get_current_device()).detach()
        return data

    def train(self, dataloader=None, mode=True):
        """Sets the dataloader to be used and switches the model to
        training or evaluation mode.

        :param dataloader: Dataloader to be used
        :param mode: If True, the model is set to training mode; otherwise, evaluation mode.
        """
        self.check_initialized()
        if mode:
            self.model.train()
        else:
            self.model.eval()
        if dataloader is not None:
            self.dataloader = dataloader
            self.data_iter = iter(dataloader)

    def zero_grad(self, forward_only=False):
        """Clears the gradients through the optimizer.
        """
        if not forward_only:
            self.check_initialized()
            self.optimizer.zero_grad()

    def get_lr(self):
        """Returns the current learning rate.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_lr()[0]
        else:
            return self.optimizer.param_groups[0]['lr']

    def step(self):
        """Updates the parameters and learning rate with the optimizer.
        """
        self.check_initialized()
        self.optimizer.step()
        # update lr scheduler
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    @abstractmethod
    def forward_backward_step(self, forward_only=False, return_loss=True):
        """The process function over a batch of data for training or evaluation.

        :param forward_only: If True, the process won't include the backward pass.
        :param return_loss: If False, the loss won't be returned.
        """
        pass
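

# A minimal sketch of how a concrete schedule might subclass BaseSchedule.
# This class is an illustration only, not part of the library: its name and
# the single-device, single-input-tensor logic below are assumptions, and
# real schedules (e.g. pipeline schedules) are more involved.
class _ExampleSchedule(BaseSchedule):

    @property
    def num_steps(self):
        # One step per batch of the attached dataloader
        return len(self.dataloader)

    def forward_backward_step(self, forward_only=False, return_loss=True):
        # load_batch() returns data and labels already on the current device;
        # this sketch assumes `data` is a single input tensor
        data, label = self.load_batch()
        output = self.model(data)
        loss = None
        if return_loss or not forward_only:
            loss = self.criterion(output, label)
        if not forward_only:
            loss.backward()
        return output, label, (loss if return_loss else None)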