from copy import deepcopy
from typing import Any, Dict, List, Tuple

from torch import Tensor
from torch.fx import Graph, Node

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
    activation_size,
    calculate_bwd_time,
    calculate_fwd_out,
    calculate_fwd_time,
    calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger

from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence

__all__ = ['CheckpointSolverRotor']


class CheckpointSolverRotor(CheckpointSolverBase):

    def __init__(self,
                 graph: Graph,
                 free_memory: float = -1,
                 cnode: List[str] = None,
                 memory_slots: int = 500,
                 optim_multiplier: float = 1.0):
        """Set up the rotor solver and build the discretized chain it works on.

        This is a simple implementation of the dynamic programming algorithm rotor
        described in https://hal.inria.fr/hal-02352969. Some code is adapted from
        https://gitlab.inria.fr/hiepacs/rotor.

        Usage:
            Assume that we have a ``GraphModule`` whose graph already carries all the
            profiling information needed, then a solution can be found with:
            >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
            >>> rotor_graph = solver.solve(force_python=True)   # otherwise use C solver
            >>> gm.graph = rotor_graph    # set the graph to a new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            free_memory (float, optional): Memory constraint for the solution, unit is byte.
                Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
            cnode (List[str], optional): Common node list, should be a subset of the inputs. Defaults to None.
            memory_slots (int, optional): Number of slots used to discretize the memory budget. Defaults to 500.
            optim_multiplier (float, optional): The multiplier of extra weight storage for the
                ``torch.optim.Optimizer``. Defaults to 1.0.
        """
        super().__init__(graph, free_memory, True, cnode, optim_multiplier)
        self.memory_slots = memory_slots

        # amount of ``self.free_memory`` represented by a single memory slot
        slot_size = self.free_memory // self.memory_slots

        # linearize the graph into a chain, then express its memory sizes in slots
        self.chain = self._construct_chain(self.graph, self.node_list)
        self.chain.discretize_all(slot_size)

        # results, populated by ``solve()``
        self.cost_table, self.back_ptr, self.sequence = None, None, None

    def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
        """Solve the checkpointing problem using rotor algorithm.

        Args:
            force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
            verbose (bool, optional): Print verbose information. Defaults to False.

        Returns:
            graph (Graph): The optimized graph, should be a copy of the original graph.
        """
        chain = self.chain

        # compute cost table
        if force_python:
            self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
        else:
            self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)

        if verbose:
            self.print_chain()

        # backtrack
        try:
            self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
                                            self.back_ptr)
            self._annotate_from_sequence(self.sequence, self.node_list)
        except ValueError as e:
            # using logger to annonce that the solver is failed
            logger = get_dist_logger()
            logger.warning(f'Checkpoint solver failed: {e}')
            raise ValueError

        if verbose:
            self.print_sequence()

        return deepcopy(self.graph)

    def print_chain(self):
        print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
        for idx in range(len(self.node_list) - 1):
            print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
                  self.chain.btmp[idx])
        print(f'Chain = {self.chain}')

    def print_sequence(self):
        print(f'Sequence = {self.sequence}')

    @classmethod
    def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
        """Build the ``Chain`` fed to the solver from the graph inputs and the node groups.

        Args:
            graph (Graph): Graph whose placeholder outputs give the first ``x`` / ``xbar`` entry.
            node_list (List[List[Node]]): Node groups; each one contributes one stage to the chain.

        Returns:
            Chain: ``Chain(ftime, btime, x, xbar, ftmp, btmp)`` built from the collected values.
        """
        input_tensors = cls._extract_input(graph)

        # index 0 of both ``x`` and ``xbar`` is the size of the graph inputs
        x = [activation_size(input_tensors)]
        xbar = [activation_size(input_tensors)]
        ftime, btime = [], []
        ftmp, btmp = [], []

        for group in node_list:
            fwd_time, bwd_time, out_size, saved_size, fwd_tmp, bwd_tmp = cls._extract_node_info(group)
            ftime.append(fwd_time)
            btime.append(bwd_time)
            x.append(out_size)
            xbar.append(saved_size)
            ftmp.append(fwd_tmp)
            btmp.append(bwd_tmp)

        # currently we view loss backward temp as zero
        btime.append(0)
        btmp.append(0)

        return Chain(ftime, btime, x, xbar, ftmp, btmp)

    @classmethod
    def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
        """Aggregate the profiling data of a list of nodes into one chain stage.

        Args:
            node (List[Node]): The nodes forming the stage, in execution order.

        Returns:
            Tuple: ``(ftime, btime, x, xbar, ftmp, btmp)`` for the stage.
        """
        saved_mem = 0
        peak_mem = 0
        fwd_cost = 0
        bwd_cost = 0

        for n in node:
            assert isinstance(n, Node), f'{n} is not a Node'
            is_runtime_op = n.target == runtime_apply or n.target == runtime_comm_spec_apply
            if is_runtime_op:
                # for these targets the memory usage is read directly from the statistics hooked in node.meta
                saved_mem += n.meta['fwd_mem_out']
                candidate = saved_mem + n.meta['fwd_mem_tmp']
            else:
                saved_mem += calculate_fwd_tmp(n) + calculate_fwd_out(n)
                candidate = saved_mem + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)
            peak_mem = max(peak_mem, candidate)

            # every node counts for at least 1.0 in both directions
            fwd_cost += max(calculate_fwd_time(n), 1.0)
            bwd_cost += max(calculate_bwd_time(n), 1.0)

        # the stage output is the output of its last node; xbar is never smaller than x
        out_mem = calculate_fwd_out(node[-1])
        saved_mem = max(out_mem, saved_mem)
        fwd_tmp = peak_mem - saved_mem
        bwd_tmp = cls._extract_btmp(node)
        return fwd_cost, bwd_cost, out_mem, saved_mem, fwd_tmp, bwd_tmp

    @staticmethod
    def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
        """Extract input tensors from a Graph"""
        input_tensors = []
        for node in graph.nodes:
            if node.op == 'placeholder':
                input_tensors.append(node.meta['fwd_out'])
        return input_tensors

    @staticmethod
    def _extract_unused_output(node: Node) -> int:
        """Return the size of ``node.meta['fwd_out']`` minus ``calculate_fwd_out(node)``."""
        total_out = activation_size(node.meta['fwd_out'])
        used_out = calculate_fwd_out(node)
        return total_out - used_out

    @staticmethod
    def _extract_btmp(node: List[Node]) -> int:
        """Compute btmp for a list of nodes by walking them in reverse order.

        For every visited node a counter initialized to its number of input nodes
        is kept. Visiting a node decrements the counter of each of its users that
        was already visited; a counter that drops to zero or below is marked as
        freed. The result is the largest value of "live size + ``bwd_mem_tmp``"
        seen during the walk (never below 0).
        """
        freed = float('-inf')
        pending = {}

        def _live_size():
            # counters still positive add their ``bwd_mem_out``; freed entries
            # subtract their forward tmp and output sizes
            total = 0
            for tracked, remaining in pending.items():
                if remaining > 0:
                    total += tracked.meta['bwd_mem_out']
                elif remaining == freed:
                    total -= calculate_fwd_tmp(tracked) + calculate_fwd_out(tracked)
            return total

        peak = 0
        for current in reversed(node):
            pending[current] = len(current.all_input_nodes)
            peak = max(peak, _live_size() + current.meta['bwd_mem_tmp'])
            for user in current.users:
                if user not in pending:
                    continue
                pending[user] -= 1
                if pending[user] <= 0:
                    pending[user] = freed
        return peak

    @staticmethod
    def _compute_table(chain: Chain, mmax: int) -> Tuple:
        """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            mmax (int): Maximum number of memory slots.

        Returns:
            cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs
            with m memory slots.
            back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice
            is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j
        """

        ftime = chain.ftime + [0.0]
        btime = chain.btime
        x = chain.x + [0]
        xbar = chain.xbar + [0]
        ftmp = chain.ftmp + [0]
        btmp = chain.btmp + [0]

        # Build table
        cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
        back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]

        # Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs
        for m in range(mmax + 1):
            for i in range(len(chain) + 1):
                limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
                if m >= limit:
                    cost_table[m][i][i] = ftime[i] + btime[i]
                else:
                    cost_table[m][i][i] = float("inf")

        # Compute tables
        for m in range(mmax + 1):
            for d in range(1, len(chain) + 1):
                for i in range(len(chain) + 1 - d):
                    idx = i + d
                    mmin = x[idx + 1] + x[i + 1] + ftmp[i]
                    if idx > i + 1:
                        mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
                    if m < mmin:
                        cost_table[m][i][idx] = float("inf")
                    else:
                        leaf_checkpoints = [(j,
                                             sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
                                            for j in range(i + 1, idx + 1)
                                            if m >= x[j]]
                        if leaf_checkpoints:
                            best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                        else:
                            best_leaf = None
                        if m >= xbar[i + 1]:
                            chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
                        else:
                            chain_checkpoint = float("inf")
                        if best_leaf and best_leaf[1] <= chain_checkpoint:
                            cost_table[m][i][idx] = best_leaf[1]
                            back_ptr[m][i][idx] = (False, best_leaf[0])
                        else:
                            cost_table[m][i][idx] = chain_checkpoint
                            back_ptr[m][i][idx] = (True,)
        return cost_table, back_ptr

    @staticmethod
    def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
        """Compute the tables with the C extension ``rotorc``, building it on demand.

        If ``rotorc`` cannot be imported, ``build_c_ext.py`` located next to this
        file is run with the current interpreter to build it in place. When the
        build fails, the Python implementation ``_compute_table`` is used instead.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            mmax (int): Maximum number of memory slots.

        Returns:
            Tuple: Whatever ``rotorc.compute_table(chain, mmax)`` returns, or the result of
            ``_compute_table(chain, mmax)`` when falling back to Python.
        """
        try:
            from .rotorc import compute_table

        # build module if module not found
        except ModuleNotFoundError:
            import importlib
            import os
            import subprocess
            import sys
            logger = get_dist_logger()
            logger.info("rotorc hasn't been built! Building library...", ranks=[0])
            this_dir = os.path.dirname(os.path.abspath(__file__))
            # ``subprocess.run`` drains stdout/stderr while waiting. ``Popen(...).wait()`` with
            # PIPE can block forever once the build output fills the OS pipe buffer.
            result = subprocess.run(
                [sys.executable, os.path.join(this_dir, 'build_c_ext.py'), "build_ext", f"--build-lib={this_dir}"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            if result.returncode == 0:
                logger.info("rotorc has been built!", ranks=[0])
                # the module file was created after the interpreter started, so make the
                # import system forget its cached directory listing before importing it
                importlib.invalidate_caches()
                from .rotorc import compute_table
            else:
                logger.warning("rotorc built failed! Using python version!", ranks=[0])
                return CheckpointSolverRotor._compute_table(chain, mmax)
        return compute_table(chain, mmax)

    @staticmethod
    def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
                   back_ptr: List[Any]) -> "Sequence":
        """Backtrack the cost table and retrieve the optimal checkpointing strategy.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            lhs (int): The left index of the interval to backtrack.
            rhs (int): The right index of the interval to backtrack.
            budget (int): The memory budget for processing this interval.
            cost_table (List[Any]): See ``._compute_table()`` for definitions
            back_ptr (List[Any]): See ``._compute_table()`` for definitions

        Raises:
            ValueError: Can not process the chain.

        Returns:
            sequence (Sequence): The sequence of executing nodes with checkpoints.
        """
        if budget <= 0:
            raise ValueError(f"Can not process a chain with negative memory {budget}")
        elif cost_table[budget][lhs][rhs] == float("inf"):
            raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")

        sequence = Sequence()
        if rhs == lhs:
            if lhs == len(chain):
                sequence += [Loss()]
            else:
                sequence += [ForwardEnable(lhs), Backward(lhs)]
            return sequence

        if back_ptr[budget][lhs][rhs][0]:
            sequence += [
                ForwardEnable(lhs),
                CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
                                                 back_ptr),
                Backward(lhs),
            ]
        else:
            best_leaf = back_ptr[budget][lhs][rhs][1]
            sequence += [ForwardCheck(lhs)]
            sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
            sequence += [
                CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
                                                 back_ptr),
                CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
            ]
        return sequence

    @staticmethod
    def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> None:
        """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.

        The flattened operation list is split at the first ``Loss`` operation.
        Operations before it drive the forward annotation, which writes
        ``n.meta['activation_checkpoint'] = [ckpt_idx]``; operations after it
        drive the nested annotation, which appends further indices to those
        lists. Finally, labels inside each region reported by
        ``_find_nested_ckpt_regions`` are padded with ``None`` to equal length.

        Nodes are modified in place; nothing is returned.

        Args:
            sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
            node_list (List[List[Node]]): The list of nodes to annotate.
        """
        op_list = sequence.list_operations()
        loss_op = next(op for op in op_list if isinstance(op, Loss))
        fwd_list = op_list[:op_list.index(loss_op)]
        bwd_list = op_list[op_list.index(loss_op) + 1:]
        ckpt_idx = 0
        in_ckpt = False
        ckpt_region = []

        # forward annotation
        # NOTE: here the position of the op inside ``fwd_list`` is used as the index into
        # ``node_list`` (the backward loop below uses ``op.index`` instead).
        for idx, op in enumerate(fwd_list, 0):
            if in_ckpt:
                if isinstance(op, ForwardNograd):
                    # still inside the current region
                    ckpt_region.append(idx)

                elif isinstance(op, ForwardEnable):
                    # region ends here: label what was collected and leave checkpoint mode
                    in_ckpt = False
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    # a new region starts right after the current one: label the
                    # collected one and start collecting again from this op
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = [idx]

            else:
                if isinstance(op, ForwardCheck):
                    in_ckpt = True
                    ckpt_region.append(idx)

        # NOTE(review): a region that is still open when ``fwd_list`` ends is not labeled by
        # the loop above, since labels are only written on ``ForwardEnable`` / ``ForwardCheck``.

        # annotate the backward if there is any nested activation checkpoint
        in_recompute = False
        for op in bwd_list:
            if in_recompute:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(op.index)

                elif isinstance(op, ForwardEnable):
                    # close the nested region; stay in recompute mode
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    # close the nested region and immediately open the next one
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = [op.index]

                elif isinstance(op, Backward):
                    # a backward op ends the recomputation; label whatever is still collected
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    in_recompute = False

            else:
                if not isinstance(op, Backward):
                    # first non-backward op starts a recomputation; nested indices restart at 0
                    in_recompute = True
                    ckpt_idx = 0
                    ckpt_region = []
                    if isinstance(op, ForwardCheck):
                        ckpt_region.append(op.index)

        # postprocess, make sure every activation checkpoint label in the
        # same activation checkpoint region (level = 0) has the same length
        op_list = []
        for node in node_list:
            op_list += node
        ckpt_regions = _find_nested_ckpt_regions(op_list)
        for (start_idx, end_idx) in ckpt_regions:
            # longest label inside the region; shorter ones are padded with ``None``
            nested_length = max(
                len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
            for idx in range(start_idx, end_idx + 1):
                op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
                                                                        len(op_list[idx].meta['activation_checkpoint']))