from typing import Any, Callable, Iterable

from torch import Tensor
from torch.nn import Module

from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.stage_manager import PipelineStageManager


class PipelineSchedule:
    """Base class for pipeline schedules.

    Stores the stage manager it is constructed with and declares the
    ``forward_backward_step`` entry point. The implementation here only
    raises ``NotImplementedError``; a usable schedule has to override it.
    """

    def __init__(self, stage_manager: PipelineStageManager) -> None:
        # Kept as a plain attribute so overriding schedules can reach it.
        self.stage_manager = stage_manager

    def forward_backward_step(
        self,
        model: Module,
        optimizer: OptimizerWrapper,
        data_iter: Iterable,
        criterion: Callable[[Any, Any], Tensor],
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> dict:
        """Run one forward and backward step of pipeline training.

        Args:
            model (Module): Model to be trained.
            optimizer (OptimizerWrapper): Optimizer to be used.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Loss function. It takes
                two arguments, the model outputs and the inputs, and returns
                the loss tensor.
            return_loss (bool, optional): Whether to return the loss.
                Defaults to False.
            return_outputs (bool, optional): Whether to return the model
                outputs. Defaults to False.

        Returns:
            dict: A dict with keys: 'loss' and 'outputs'.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError