mirror of https://github.com/hpcaitech/ColossalAI
66 lines
2.5 KiB
Python
66 lines
2.5 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from typing import Iterable
|
|
|
|
import torch
|
|
|
|
from colossalai.engine import Engine
|
|
from ._base_schedule import BaseSchedule
|
|
from colossalai.utils import conditional_context
|
|
|
|
|
|
class NonPipelineSchedule(BaseSchedule):
    """A helper schedule class for running environments without pipeline parallelism.

    Each call to :meth:`forward_backward_step` loads one batch from the data
    iterator and feeds it to the model. After computing the output (and,
    optionally, the loss), it runs back propagation through the engine unless
    running in forward-only mode. This schedule does not update parameters
    itself; it only performs the forward and backward passes.
    """

    def forward_backward_step(self,
                              engine: Engine,
                              data_iter: Iterable,
                              forward_only: bool = False,
                              return_loss: bool = True,
                              return_output_label: bool = True):
        """Load a batch of data, feed it to the model and optionally run backward.

        :param engine: Engine used to run the model, the loss criterion and the backward pass
        :param data_iter: Data iterator of the dataloader, e.g. ``iter(dataloader)``
        :param forward_only: If True, only the forward pass is run (under ``torch.no_grad()``);
            otherwise back propagation is executed via ``engine.backward``
        :param return_loss: If True, the loss is computed and returned
        :param return_output_label: If True, the model output and the label are returned
        :type engine: :class:`colossalai.engine.Engine`
        :type data_iter: Iterator
        :type forward_only: bool, optional
        :type return_loss: bool, optional
        :type return_output_label: bool, optional

        :raises AssertionError: If ``forward_only`` is False while ``return_loss`` is False,
            because back propagation requires the loss
        :return: ``(output, label, loss)``; ``output`` and ``label`` are None when
            ``return_output_label`` is False, and ``loss`` is None when ``return_loss`` is False
        :rtype: Tuple
        """
        # Raised explicitly instead of via ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept so existing callers are unaffected.
        if not (forward_only or return_loss):
            raise AssertionError(
                "The argument 'return_loss' has to be True when 'forward_only' is False, but got False.")
        data, label = self.load_batch(data_iter)

        # Pre-bind so ``loss`` is always defined, even when it is not computed.
        loss = None

        # forward; torch.no_grad() is applied only in forward-only mode, so
        # gradients are tracked whenever a backward pass follows.
        with conditional_context(torch.no_grad(), enable=forward_only):
            output = self._call_engine(engine, data)
            if return_loss:
                loss = self._call_engine_criterion(engine, output, label)

        if not forward_only:
            engine.backward(loss)

        # Drop output/label when the caller did not ask for them.
        if not return_output_label:
            output, label = None, None
        return output, label, loss
|