ColossalAI/colossalai/engine/schedule/_non_pipeline_schedule.py

66 lines
2.6 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Iterable
import torch
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
class NonPipelineSchedule(BaseSchedule):
    """A helper schedule class for running without pipeline parallelism.

    Each call to :meth:`forward_backward_step` loads one batch from the data
    iterator, feeds it to the model through the engine, optionally computes the
    loss with the engine's criterion, and runs back propagation unless
    ``forward_only`` is set. This class does not call ``step`` itself; the
    parameter update is presumably performed by the caller (e.g. the engine /
    trainer) — confirm against callers.

    Args:
        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
            and it will be executed in load_batch. (Consumed by the base class constructor — this class
            defines no ``__init__`` of its own.)
    """

    def forward_backward_step(self,
                              engine,
                              data_iter: Iterable,
                              forward_only: bool = False,
                              return_loss: bool = True,
                              return_output_label: bool = True):
        """Load one batch, run the forward pass and, unless ``forward_only``, the backward pass.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                If True, only the forward pass is run, inside ``torch.no_grad()``; otherwise
                back propagation is executed via ``engine.backward``.
            return_loss (bool, optional): If True, the loss is computed with the engine's
                criterion and returned. Must be True when ``forward_only`` is False.
            return_output_label (bool, optional): Output and label will be returned if True.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of ``(output, label, loss)``. ``output`` and
            ``label`` are None when ``return_output_label`` is False; ``loss`` is None when
            ``return_loss`` is False.

        Raises:
            AssertionError: If ``forward_only`` and ``return_loss`` are both False, since a
                backward pass needs a loss.
        """
        assert forward_only or return_loss, \
            "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

        data, label = self.load_batch(data_iter)
        loss = None

        # forward: autograd tracking is disabled when no backward pass will follow.
        with conditional_context(torch.no_grad(), enable=forward_only):
            output = self._call_engine(engine, data)
            if return_loss:
                loss = self._call_engine_criterion(engine, output, label)

        if not forward_only:
            # The assertion above guarantees ``loss`` was computed on this path.
            engine.backward(loss)

        if not return_output_label:
            output = label = None
        return output, label, loss