import copy
from typing import Dict, List, Tuple

from torch.fx.node import Node

from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


class SearchChunk(object):
    """
    This is the core class for AutoChunk.

    It defines the framework of the strategy of AutoChunk.
    Chunks will be selected one by one until the search stops.

    The chunk search is as follows:
    1. find the peak memory node
    2. find the max chunk region according to the peak memory node
    3. find all possible chunk regions in the max chunk region
    4. find the best chunk region for current status
    5. goto 1

    Attributes:
        gm: graph model
        print_mem (bool): print estimated memory
        trace_indice: trace the flow of every dim of every node to find all free dims
        trace_flow: determine the region chunk strategy
        reorder_graph: reorder nodes to improve chunk efficiency
        estimate_memory: estimate memory with chunk
        select_chunk: select the best chunk region

    Args:
        gm: graph model
        max_memory (int): max memory in MB
        print_mem (bool): print estimated memory
        print_progress (bool): log the progress of tracing and searching
    """

    def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
        """
        Build all helper components and run the initial indice trace.

        Args:
            gm: graph model; only ``gm.graph.nodes`` is read here
            max_memory: memory limit, stored on the instance and forwarded to SelectChunk;
                None means no limit is applied when the peak region is searched
            print_mem (bool): print estimated memory once the search has finished
            print_progress (bool): log the progress of tracing and searching
        """
        self.print_mem = print_mem
        self.max_memory = max_memory
        self.print_progress = print_progress
        # snapshot of the graph nodes, shared by every helper below
        self.node_mgr = NodeMgr(list(gm.graph.nodes))
        self.trace_indice = TraceIndice(self.node_mgr)
        self.estimate_memory = EstimateMemory()
        # _init_trace reads node_mgr, trace_indice, estimate_memory and print_progress,
        # so it has to run after they are set
        self._init_trace()
        self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
        self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
        self.select_chunk = SelectChunk(
            self.trace_indice,
            self.estimate_memory,
            self.reorder_graph,
            self.node_mgr,
            max_memory=max_memory,
        )

    def _init_trace(self) -> None:
        """
        Run the indice trace, restricted by the active nodes of the graph.

        The active nodes come from a memory estimation of the unchunked graph and
        are handed to trace_indice before tracing, which bounds the trace range of
        every node and reduces the computation complexity of trace_indice.
        """
        # only the third item of the estimation (active nodes) is needed here
        node_list = self.node_mgr.get_node_list()
        _, _, active_nodes = self.estimate_memory.estimate_chunk_inference_mem(node_list)

        if self.print_progress:
            get_logger().info("AutoChunk start tracing indice")

        # the trace range has to be set before the trace itself
        tracer = self.trace_indice
        tracer.set_active_nodes(active_nodes)
        tracer.trace_indice()

    def _find_peak_region(self, mem_peak: List) -> List[int]:
        """
        Find the peak memory node and extend it to nearby nodes that exceed max memory.

        The region starts as the single node with the highest peak memory. If
        ``self.max_memory`` is set, the region grows to the left and to the right
        over every node whose peak memory exceeds ``self.max_memory``; a scan stops
        after 3 consecutive nodes that stay within the limit.

        Args:
            mem_peak (List): peak memory for every node

        Returns:
            peak_region (List[int]): [start_idx, end_idx] of the peak region (inclusive)
        """
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        peak_region = [max_idx, max_idx]
        if self.max_memory is None:
            return peak_region

        # number of consecutive within-limit nodes that ends a scan
        tolerance = 3

        # to left
        count = 0
        for i in range(max_idx - 1, -1, -1):
            if mem_peak[i] > self.max_memory:
                peak_region[0] = i
                # reset so that only consecutive within-limit nodes stop the scan,
                # same as the scan to the right
                count = 0
            else:
                count += 1
            if count >= tolerance:
                break
        # to right
        # NOTE(review): the last index is never scanned (upper bound is len(mem_peak) - 1);
        # presumably this skips the graph output node - confirm
        count = 0
        for i in range(max_idx + 1, len(mem_peak) - 1):
            if mem_peak[i] > self.max_memory:
                peak_region[1] = i
                count = 0
            else:
                count += 1
            if count >= tolerance:
                break

        return peak_region

    def _search_max_chunk_region(self, active_node: List, peak_region: List, chunk_regions: List = None) -> Tuple:
        """
        Search max chunk region according to the peak memory region

        The chunk region extends from the peak region to both sides (at most
        100 nodes each) and stops where the number of active nodes is minimal.
        On ties the candidate closest to the peak region is kept.

        Args:
            active_node (List): active node status for every node
            peak_region (List): [start_idx, end_idx] of the peak memory region
            chunk_regions (List): chunk region infos already selected

        Returns:
            None if the peak region is already covered by an existing chunk region
            or the found region lies inside one, otherwise:
            chunk_region_start (int)
            chunk_region_end (int)
        """
        # check if peak region already in chunkinfo
        if chunk_regions is not None:
            for info in chunk_regions:
                lower, upper = info["region"][0], info["region"][1]
                if lower < peak_region[0] <= upper or lower < peak_region[1] <= upper:
                    return None

        active_node_num = [len(nodes) for nodes in active_node]
        window_size = 100
        # min() returns the first minimum it meets, i.e. the index closest to the peak.
        # No magic upper bound is needed, so the result is defined for any node count.
        # search min for start, walking to the left
        chunk_region_start = min(
            range(peak_region[0], max(peak_region[0] - window_size, -1), -1),
            key=active_node_num.__getitem__,
        )
        # search min for end, walking to the right
        chunk_region_end = min(
            range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))),
            key=active_node_num.__getitem__,
        )

        # avoid chunk regions overlap
        if chunk_regions is not None:
            for info in chunk_regions:
                region = info["region"]
                if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                    # fully inside an existing region
                    return None
                elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
                    # head overlaps: start right after the existing region
                    chunk_region_start = region[1] + 1
                elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
                    # tail overlaps: end right before the existing region
                    chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
        """
        Collect every chunk info that can be built for one region.

        For the given region start and end, every dim of each input of the start
        node is paired with every dim of the end node's output. A pair is kept only
        if trace_flow accepts it as region start/end; for an accepted pair a flow
        search over the region is run and its result, if any, is collected.

        Args:
            input_trace (List): node's input trace in region
            output_trace (List): node's output trace in region
            start_idx (int): region start node index
            end_idx (int): region end node index

        Returns:
            chunk_infos: possible regions found, empty if the start node has more
                than one traced input
        """
        region_inputs = input_trace[start_idx]
        # TODO need to be removed
        if len(region_inputs) > 1:
            return []

        region_output = output_trace[end_idx]
        end_node = self.node_mgr.get_node_by_idx(end_idx)
        flow = self.trace_flow

        found = []
        for end_dim in range(len(region_output["indice"])):
            for start_node, start_trace in region_inputs.items():
                for start_dim in range(len(start_trace["indice"])):
                    linked = flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx)
                    if linked:
                        # flow search
                        info = flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                        if info is not None:
                            found.append(info)
        return found

    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: List) -> List:
        """
        Search every possible region within the max chunk region.

        Every start index in [max_chunk_region[0], peak_region[0]] is combined with
        every end index in [peak_region[1], max_chunk_region[1]]; pairs whose start
        or end node is a non compute node are skipped.

        Args:
            max_chunk_region (Tuple): (start_idx, end_idx) bounding the search
            peak_region (List): [start_idx, end_idx] of the peak memory region

        Returns:
            possible_chunk_region (List): chunk infos found for all candidate pairs
        """
        possible_chunk_region = []
        # NOTE(review): the copy presumably protects the stored traces from changes
        # made during the search - confirm whether it is still required
        output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
        # trace of a node's input nodes, only node args that are not filtered out are traced
        input_trace = [
            {
                arg: self.trace_indice._find_trace_from_node(arg)
                for arg in n.args
                if isinstance(arg, Node) and not is_non_compute_node_except_placeholder(arg)
            }
            for n in self.node_mgr.get_node_list()
        ]

        get_node = self.node_mgr.get_node_by_idx
        # the end candidates do not depend on the start index, so filter them only once
        end_candidates = [
            end_idx for end_idx in range(peak_region[1], max_chunk_region[1] + 1)
            if not is_non_compute_node(get_node(end_idx))
        ]

        for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
            # skip non compute nodes
            if is_non_compute_node(get_node(start_idx)):
                continue
            for end_idx in end_candidates:
                # select free dim
                possible_chunk_region.extend(self._find_chunk_info(input_trace, output_trace, start_idx, end_idx))
        return possible_chunk_region

    def _step_search(
        self,
        mem_peak: List[float],
        active_node: List[List[Node]],
        chunk_infos: List[Dict],
    ) -> Dict:
        """
        Find one chunk region

        The chunk search is as follows:
        1. find the peak memory region
        2. find the max chunk region according to the peak memory region
        3. find all possible chunk regions in the max chunk region
        4. find the best chunk region for current status
        5. pass the best chunk region through reorder_graph.reorder_all

        Args:
            mem_peak (List): peak memory for every node
            active_node (List[List[Node]]): active node for every node
            chunk_infos (List[Dict]): all chunk info selected so far

        Returns:
            best_chunk_region (Dict): result of reorder_graph.reorder_all for the
                selected region, or None if no max chunk region can be found for
                the current peak
        """
        peak_region = self._find_peak_region(mem_peak)
        max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
        # no region left around the peak, tell the caller to stop searching
        if max_chunk_region is None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
        # whatever select_chunk returns is passed on to reorder_all unchanged
        best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
        return best_chunk_region

    def search_region(self) -> Dict:
        """
        Select chunk regions one after another until no further region is found.

        Each round:
        1. Estimate memory with the chunks selected so far
        2. Ask _step_search for the best next chunk
        3. Stop as soon as _step_search returns None

        If print_mem is set, the final estimation is printed once and the flag
        is cleared afterwards.

        Returns:
            chunk_infos: list of all selected chunk infos, in selection order
        """
        if self.print_progress:
            get_logger().info("AutoChunk start searching chunk regions")

        chunk_infos = []
        # baseline estimation without any chunk
        mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())

        next_chunk = self._step_search(mem_peak, active_node, chunk_infos)
        while next_chunk is not None:
            chunk_infos.append(next_chunk)

            # estimate again with the new chunk; the node list is fetched anew every round
            mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
                self.node_mgr.get_node_list(), chunk_infos)

            if self.print_progress:
                region = next_chunk["region"]
                get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
                                  (len(chunk_infos), region[0], region[1]))

            next_chunk = self._step_search(mem_peak, active_node, chunk_infos)

        if not self.print_mem:
            return chunk_infos

        # print only once, later calls stay silent
        self.print_mem = False
        self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
                                                          chunk_infos,
                                                          print_mem=True)
        return chunk_infos