from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Tuple

import torch
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten


class ColoParamOpHook(ABC):
    """
    Hook which is triggered by each operation when operands contain ColoParameter.
    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
    ``post_forward``, ``pre_backward`` and ``post_backward``.
    These four methods apply a list of ColoParameter as input args.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


class ColoParamOpHookManager:
    """
    Manage your param op hooks. It only has static methods.
    The only static method you should call is ``use_hooks(*hooks)``.
    """

    hooks: Tuple[ColoParamOpHook, ...] = tuple()

    @staticmethod
    @contextmanager
    def use_hooks(*hooks: ColoParamOpHook):
        """Change the param op hooks you use. Nested calling is allowed.

        Example:
            >>> with ColoParamOpHookManager.use_hooks(*hooks):
            >>>     do_something()
            >>>     with ColoParamOpHookManager.use_hooks():
            >>>         // clear hooks
            >>>         do_something()
        """
        try:
            old_param_op_hooks = ColoParamOpHookManager.hooks
            ColoParamOpHookManager.hooks = hooks
            yield
        finally:
            ColoParamOpHookManager.hooks = old_param_op_hooks

    @staticmethod
    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_forward(params)

    @staticmethod
    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
        for hook in ColoParamOpHookManager.hooks:
            hook.post_forward(params)

    @staticmethod
    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_backward(params)

    @staticmethod
    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
        for hook in ColoParamOpHookManager.hooks:
            hook.post_backward(params)

    @staticmethod
    def pre_op(params: List[torch.Tensor], *args: Any) -> list:
        ColoParamOpHookManager._trigger_pre_forward(params)
        # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
        # if one of the input requires grad, all the output will be treated as requires grad
        # and will have grad fn even the corresponding input does not require grad
        # we have to extract tensors requiring grad into flat list and then merge them back
        grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
        new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
        return _merge_args(new_grad_args, other_args, grad_flags, spec)

    @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
        ColoParamOpHookManager._trigger_post_forward(params)
        return PostFwdPreBwd.apply(params, arg)

    @staticmethod
    def has_hook() -> bool:
        return len(ColoParamOpHookManager.hooks) > 0


class PreFwdPostBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        ColoParamOpHookManager._trigger_post_backward(ctx.params)
        return (None,) + grads


class PostFwdPreBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params, args):
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
        return (None,) + grads


def _is_grad_tensor(obj) -> bool:
    if torch.is_tensor(obj):
        if obj.grad_fn is not None or obj.requires_grad:
            return True
    return False


def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
    flat_args, spec = tree_flatten(args)
    grad_args = []
    other_args = []
    grad_flags = []
    for arg in flat_args:
        flag = _is_grad_tensor(arg)
        grad_flags.append(flag)
        if flag:
            grad_args.append(arg)
        else:
            other_args.append(arg)
    assert len(grad_args) > 0
    return grad_args, other_args, grad_flags, spec


def _merge_args(grad_args, other_args, grad_flags, spec):
    grad_iter = iter(grad_args)
    other_iter = iter(other_args)
    flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
    return tree_unflatten(flat_args, spec)