#!/usr/bin/env python # -*- encoding: utf-8 -*- from typing import Iterable import torch import inspect from ._base_schedule import BaseSchedule from colossalai.utils import conditional_context from typing import Callable class NonPipelineSchedule(BaseSchedule): """A helper schedule class for no pipeline parallelism running environment. During one process, it loads a batch of dataset and feeds it to the model. After getting the output and calculating the loss, it will use :meth:`step` to update the parameters if it is in training mode. Args: data_process_func (Callable, optional): The preprocessing function which receives a batch of data and returns a tuple in the form of (data, label). and it will be executed in load_batch. Example: # this shows an example of customized data_process_func def data_process_func(dataloader_output): item1, item2, item3 = dataloader_output data = (item1, item2) label = item3 return data, label """ def __init__(self, data_process_func: Callable = None): # check that non-pipeline schedule data process func only takes in one parameter # which is the batch data if data_process_func: sig = inspect.signature(data_process_func) assert len(sig.parameters) == 1, \ 'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \ 'which is a tuple of tensors for the current batch, ' \ 'i.e. data_process_func(dataloader_output).' super().__init__(data_process_func) def forward_backward_step(self, engine, data_iter: Iterable, forward_only: bool = False, return_loss: bool = True, return_output_label: bool = True): """The process function that loads a batch of dataset and feeds it to the model. The returned labels and loss will None if :attr:`return_loss` is False. Args: engine (colossalai.engine.Engine): Colossalai engine for training and inference. data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will be executed. return_loss (bool, optional): Loss will be returned if True. return_output_label (bool, optional): Output and label will be returned if True. Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ assert forward_only or return_loss, \ "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." batch_data = self.load_batch(data_iter) if self.data_process_func: data, label = self.data_process_func(batch_data) else: # if not batch data process func is given, # then we regard the batch data as a simple tuple of (data, label) data, label = batch_data # forward with conditional_context(torch.no_grad(), enable=forward_only): output = self._call_engine(engine, data) if return_loss: loss = self._call_engine_criterion(engine, output, label) if not forward_only: engine.backward(loss) if return_output_label: if return_loss: return output, label, loss else: return output, label, None else: if return_loss: return None, None, loss else: return None, None, None