#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Iterable

import torch

from colossalai.engine import Engine
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context


class NonPipelineSchedule(BaseSchedule):
    """A helper schedule class for running without pipeline parallelism.

    Each call to :meth:`forward_backward_step` loads one batch from the data
    iterator, feeds it to the model through the engine, optionally computes
    the loss, and, unless ``forward_only`` is set, back-propagates that loss
    via ``engine.backward``. The parameter update itself is not performed in
    this class.
    """

    def forward_backward_step(self,
                              engine: Engine,
                              data_iter: Iterable,
                              forward_only: bool = False,
                              return_loss: bool = True,
                              return_output_label: bool = True):
        """Load a batch from ``data_iter`` and run the forward pass.

        The backward pass also runs unless ``forward_only`` is True.
        The returned loss is None if :attr:`return_loss` is False. The
        returned output and label are None if :attr:`return_output_label`
        is False.

        :param engine: Engine used for the forward pass, the loss computation
            and the backward pass
        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
        :param forward_only: If True, only the forward pass is run (with
            gradient tracking disabled); otherwise back propagation is executed
        :param return_loss: Loss will be computed and returned if True
        :param return_output_label: Output and label will be returned if True
        :type engine: Engine
        :type data_iter: Iterator
        :type forward_only: bool, optional
        :type return_loss: bool, optional
        :type return_output_label: bool, optional

        :raises AssertionError: If both ``forward_only`` and ``return_loss``
            are False, because the backward pass needs a loss.
        :return: A 3-tuple ``(output, label, loss)``. ``output`` and ``label``
            are None when ``return_output_label`` is False. ``loss`` is None
            when ``return_loss`` is False.
        :rtype: Tuple
        """
        assert forward_only or return_loss, \
            "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
        data, label = self.load_batch(data_iter)

        # Bind loss up front so every return path can reference it safely.
        loss = None

        # Forward pass. conditional_context is expected to enter
        # torch.no_grad() only when forward_only is True, i.e. gradients are
        # tracked only when a backward pass will follow.
        with conditional_context(torch.no_grad(), enable=forward_only):
            output = self._call_engine(engine, data)
            if return_loss:
                loss = self._call_engine_criterion(engine, output, label)

        # Backward pass. The assertion above guarantees that return_loss was
        # True here, so loss has been computed.
        if not forward_only:
            engine.backward(loss)

        if return_output_label:
            return output, label, loss
        return None, None, loss