#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC

from torch import Tensor


class BaseHook(ABC):
    """Base class for hooks that run user-defined actions at fixed points
    of training or evaluation.

    Every callback below is a no-op by default; subclasses override only the
    ones they need.

    Args:
        priority (int): Priority in the printing; hooks with a smaller
            priority are printed in front.
    """

    def __init__(self, priority: int) -> None:
        self.priority = priority

    # ------------------------------------------------------------------
    # Attachment
    # ------------------------------------------------------------------
    def after_hook_is_attached(self, trainer):
        """Hook point: actions after hooks are attached to the trainer."""

    # ------------------------------------------------------------------
    # Training callbacks
    # ------------------------------------------------------------------
    def before_train(self, trainer):
        """Hook point: actions before training."""

    def after_train(self, trainer):
        """Hook point: actions after training."""

    def before_train_iter(self, trainer):
        """Hook point: actions before running a training iteration."""

    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Hook point: actions after running a training iteration.

        Args:
            trainer (:class:`Trainer`): Trainer which is using this hook.
            output (:class:`torch.Tensor`): Output of the model.
            label (:class:`torch.Tensor`): Labels of the input data.
            loss (:class:`torch.Tensor`): Loss between the output and input data.
        """

    def before_train_epoch(self, trainer):
        """Hook point: actions before starting a training epoch."""

    def after_train_epoch(self, trainer):
        """Hook point: actions after finishing a training epoch."""

    # ------------------------------------------------------------------
    # Evaluation callbacks
    # ------------------------------------------------------------------
    def before_test(self, trainer):
        """Hook point: actions before evaluation."""

    def after_test(self, trainer):
        """Hook point: actions after evaluation."""

    def before_test_epoch(self, trainer):
        """Hook point: actions before starting a testing epoch."""

    def after_test_epoch(self, trainer):
        """Hook point: actions after finishing a testing epoch."""

    def before_test_iter(self, trainer):
        """Hook point: actions before running a testing iteration."""

    def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Hook point: actions after running a testing iteration.

        Args:
            trainer (:class:`Trainer`): Trainer which is using this hook.
            output (:class:`torch.Tensor`): Output of the model.
            label (:class:`torch.Tensor`): Labels of the input data.
            loss (:class:`torch.Tensor`): Loss between the output and input data.
        """

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def init_runner_states(self, trainer, key, val):
        """Initialize an entry of the trainer's state if it is missing.

        ``trainer.states[key]`` is set to ``val`` only when ``key`` is not
        already present; an existing entry is left untouched.

        Args:
            trainer (:class:`Trainer`): Trainer which is using this hook.
            key: Key of the state entry to initialize.
            val: Value stored under ``key`` when the entry does not exist yet.
        """
        states = trainer.states
        if key in states:
            # Never overwrite a value that has already been registered.
            return
        states[key] = val