ColossalAI/colossalai/engine/schedule/_non_pipeline_schedule.py

64 lines
2.4 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Iterable
import torch
import torch.nn as nn
from colossalai.engine import Engine
from torch.optim import Optimizer
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
class NonPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def forward_backward_step(self,
engine: Engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param engine: Model for training and inference
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param return_loss: Loss will be returned if True
:type engine: Iterator
:type data_iter: Iterator
:type forward_only: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
:rtype: Tuple[:class:`torch.Tensor`]
"""
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
data, label = self.load_batch(data_iter)
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
output = engine(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = engine.criterion(*output, *label)
if not forward_only:
engine.backward(loss)
if return_loss:
return output, label, loss
else:
return output, None, None