#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Iterable
from torch.nn import Module
from torch.nn.modules.loss import _Loss

from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.logging import get_dist_logger


class Engine:
    """Basic engine class for training and evaluation. It runs a specific process method
    :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
    It controls an iteration in training.

    Args:
        model (``torch.nn.Module``): The neural network model.
        optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
        criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
        gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handlers;
            each handler's ``handle_gradient()`` is called in :meth:`step` before gradient clipping.
        clip_grad_norm (float, optional): Max norm passed to ``optimizer.clip_grad_norm`` in :meth:`step`.
        ophook_list (List[``BaseOpHook``], optional): Op hooks registered on the model via
            ``register_ophooks_recursively`` at construction time. The list is copied, so the
            caller's list is never modified by :meth:`add_hook`.
        verbose (bool): Whether to display log info (stored on the engine; not consulted by the
            methods of this class).
        schedule (``BaseSchedule``, optional): Runtime schedule. Defaults to ``NonPipelineSchedule()``.

    Examples:
        >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
        >>> model = ...
        >>> criterion = ...
        >>> optimizer = ...
        >>> train_dataloader = ...
        >>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion)
        >>> engine.train()
        >>> for inputs, labels in train_dataloader:
        ...     # set gradients to zero
        ...     engine.zero_grad()
        ...     # run forward pass
        ...     outputs = engine(inputs)
        ...     # compute loss value and run backward pass
        ...     loss = engine.criterion(outputs, labels)
        ...     engine.backward(loss)
        ...     # update parameters
        ...     engine.step()

    Examples of using Engine in training can be found in
    `Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_ and
    `Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_.
    """

    def __init__(self,
                 model: Module,
                 optimizer: "ColossalaiOptimizer",
                 criterion: Optional[_Loss] = None,
                 gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                 clip_grad_norm: float = 0.0,
                 ophook_list: Optional[List[BaseOpHook]] = None,
                 verbose: bool = True,
                 schedule: Optional[BaseSchedule] = None):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._clip_grad_norm = clip_grad_norm
        self._verbose = verbose
        self._logger = get_dist_logger()

        # state
        self.training = True    # default

        # Copy the caller's sequences: add_hook() appends to self._ophook_list, and storing the
        # argument by reference would silently mutate a list the caller owns (or fail on a tuple).
        self._gradient_handlers = list(gradient_handlers) if gradient_handlers else []
        self._ophook_list = list(ophook_list) if ophook_list else []

        # build schedule
        if schedule is not None:
            assert isinstance(schedule, BaseSchedule), \
                f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
            self._schedule = schedule
        else:
            self._schedule = NonPipelineSchedule()
        # only pipeline schedules have their pre_processing step invoked with this engine
        if self.uses_pipeline:
            self._schedule.pre_processing(self)

        # register hooks, if any
        if self._ophook_list:
            register_ophooks_recursively(self._model, self._ophook_list)

    @property
    def ophooks(self) -> List[BaseOpHook]:
        """The list of currently attached op hooks (the engine's own list, not a copy)."""
        return self._ophook_list

    @property
    def model(self) -> Module:
        """Model attached to the engine."""
        return self._model

    @property
    def optimizer(self):
        """Optimizer attached to the engine."""
        return self._optimizer

    @property
    def criterion(self) -> Optional[_Loss]:
        """Criterion attached to the engine (may be ``None``)."""
        return self._criterion

    @property
    def schedule(self) -> BaseSchedule:
        """Schedule attached to the engine."""
        return self._schedule

    @property
    def uses_pipeline(self) -> bool:
        """Whether the schedule is a (possibly interleaved) pipeline-parallel schedule."""
        return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))

    def add_hook(self, ophook: BaseOpHook) -> None:
        """Attach an op hook and register the hook list on the model.

        A warning is logged if a hook of exactly the same type is already attached;
        the new hook is added regardless.

        Args:
            ophook (``BaseOpHook``): The hook instance to attach.
        """
        if any(type(h) is type(ophook) for h in self._ophook_list):
            self._logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}")
        self._ophook_list.append(ophook)
        # NOTE(review): the full list (including previously registered hooks) is passed again,
        # so earlier hooks may end up registered on the model more than once - confirm how
        # register_ophooks_recursively treats repeated registration.
        register_ophooks_recursively(self._model, self._ophook_list)

    def remove_hook(self, ophook: BaseOpHook) -> None:
        """Remove an op hook.

        Not supported yet: this only logs a warning and leaves the attached hooks unchanged.

        Args:
            ophook (``BaseOpHook``): The hook instance that would be removed.
        """
        self._logger.warning("removing hooks is currently not supported")

    def zero_grad(self) -> None:
        """Set the gradient of parameters to zero via ``optimizer.zero_grad()``."""
        self.optimizer.zero_grad()

    def step(self):
        """Execute a parameter update.

        Runs every gradient handler, clips gradients through
        ``optimizer.clip_grad_norm(model, clip_grad_norm)``, then calls ``optimizer.step()``.

        Returns:
            Whatever ``optimizer.step()`` returns.
        """
        self._all_reduce_gradients()
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
        return self.optimizer.step()

    def backward(self, loss: Tensor):
        """Start backward propagation given the loss value computed by a loss function.

        After the optimizer's backward pass, ``post_iter()`` is called on every attached op hook.

        Args:
            loss (:class:`torch.Tensor`): Loss value computed by a loss function.

        Returns:
            Whatever ``optimizer.backward(loss)`` returns.
        """
        ret = self.optimizer.backward(loss)
        self._post_iter_ophooks()
        return ret

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        """Start backward propagation given the gradient of the output tensor.

        After the optimizer's backward pass, ``post_iter()`` is called on every attached op hook.

        Args:
            tensor (:class:`torch.Tensor`): Output tensor.
            grad (:class:`torch.Tensor`): Gradient passed back to the output.

        Returns:
            Whatever ``optimizer.backward_by_grad(tensor, grad)`` returns.
        """
        ret = self.optimizer.backward_by_grad(tensor, grad)
        self._post_iter_ophooks()
        return ret

    def _post_iter_ophooks(self) -> None:
        """Notify every attached op hook that a backward pass has finished."""
        for ophook in self._ophook_list:
            ophook.post_iter()

    def __call__(self, *args, **kwargs):
        """Run the forward step for the model.

        Returns:
            Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`: Output of the model.
        """
        return self.model(*args, **kwargs)

    def _all_reduce_gradients(self) -> None:
        """Handles all-reduce operations of gradients across different parallel groups
        by calling ``handle_gradient()`` on each configured gradient handler, in order.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def execute_schedule(self, data_iter: Iterable, **kwargs):
        """Run the forward, loss computation, and backward for the model via the schedule.

        Args:
            data_iter (Iterable): Data source handed to ``schedule.forward_backward_step``.
            **kwargs: Forwarded unchanged to ``schedule.forward_backward_step``.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
        """
        output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
        return output, label, loss

    def train(self) -> None:
        """Sets the engine and the model to training mode."""
        self.training = True
        self._model.train()

    def eval(self) -> None:
        """Sets the engine and the model to evaluation mode."""
        self.training = False
        self._model.eval()