from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import NodeMgr, is_non_compute_node


class SelectChunk(object):
    """Choose the next chunk region to apply, and the chunk size to use for it.

    The strategy is fixed at construction time:

    * ``"min_memory"`` (``max_memory`` is ``None``): pick the legal candidate
      whose estimated peak memory is lowest and give it ``chunk_size = 1``.
    * ``"fit_memory"`` (``max_memory`` given, in MB): stop (return ``None``) once
      the current peak is already below ``max_memory``; otherwise pick, among the
      legal candidates whose estimated region peak is below ``max_memory``, the
      one covering the fewest compute nodes, then search for a chunk size that
      keeps its region peak below ``max_memory``.
    """

    def __init__(
        self,
        trace_indice: TraceIndice,
        estimate_memory: EstimateMemory,
        reorder_graph: ReorderGraph,
        node_mgr: NodeMgr,
        max_memory=None,
    ):
        self.trace_indice = trace_indice
        self.estimate_memory = estimate_memory
        self.reorder_graph = reorder_graph
        self.node_mgr = node_mgr
        # NOTE: the attribute keeps its historical spelling "stratge" because
        # other code may read it.
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory  # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
        """Dispatch to the region selector of the configured strategy.

        Args:
            possible_chunk_regions (list[dict]): candidate chunk infos (each has a
                ``"region"`` entry). Illegal candidates are removed from this
                list in place.
            chunk_infos (list[dict]): chunk infos selected so far.
            mem_peak (list): peak-memory estimates of the current graph; only its
                maximum is used, and only by the ``"fit_memory"`` strategy.

        Returns:
            dict | None: the chosen candidate with ``"chunk_size"`` set, or
            ``None`` when there is nothing (more) to chunk.

        Raises:
            RuntimeError: on an unknown strategy, or (``"fit_memory"``) when no
                legal candidate fits under ``max_memory``.
        """
        if self.stratge == "min_memory":
            return self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
        if self.stratge == "fit_memory":
            return self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
        raise RuntimeError(f"Unknown chunk selection strategy: {self.stratge!r}")

    def _remove_illegal_regions(self, possible_chunk_regions, chunk_infos):
        """Drop, in place, every candidate rejected by ``_is_legal_region``."""
        # Slice assignment keeps the caller's list object, which is mutated in place.
        possible_chunk_regions[:] = [
            region for region in possible_chunk_regions if self._is_legal_region(region, chunk_infos)
        ]

    def _get_region_max_mem(self, node_list, chunk_infos, chunk_info, window=None):
        """Estimate the peak memory over a node window with ``chunk_info`` applied.

        Args:
            node_list: node list passed to ``estimate_chunk_inference_mem``.
            chunk_infos (list[dict]): already-selected chunk infos.
            chunk_info (dict): the chunk info being evaluated (appended to
                ``chunk_infos`` for the estimate only; the list is not modified).
            window (sequence | None): inclusive ``(start, end)`` node indices to
                take the maximum over; defaults to ``chunk_info["region"]``.

        Returns:
            The maximum of the estimated peaks inside ``window``.
        """
        # NOTE(review): element [0] of the estimator result is used as the
        # per-node peak list, as in the rest of this class — confirm upstream.
        mem_peak = self.estimate_memory.estimate_chunk_inference_mem(node_list, chunk_infos + [chunk_info])[0]
        if window is None:
            window = chunk_info["region"]
        return max(mem_peak[window[0] : window[1] + 1])

    def _evaluate_region(self, region, chunk_infos, window=None):
        """Tentatively reorder the graph for ``region`` and estimate its peak memory.

        Args:
            region (dict): candidate chunk info; a copy is reordered, the
                original is left untouched here.
            chunk_infos (list[dict]): already-selected chunk infos.
            window (sequence | None): inclusive ``(start, end)`` window for the
                peak; defaults to the reordered candidate's own region.

        Returns:
            dict: ``chunk_info`` (the original candidate), ``chunk_max_mem``,
            ``reorder_chunk_info`` and ``reorder_node_list``.
        """
        cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), region.copy())
        return {
            "chunk_info": region,
            "chunk_max_mem": self._get_region_max_mem(cur_node_list, chunk_infos, cur_region, window),
            "reorder_chunk_info": cur_region,
            "reorder_node_list": cur_node_list,
        }

    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
        """Pick the shortest legal candidate that fits under ``max_memory`` and size it.

        Returns:
            dict | None: the chosen candidate with ``"chunk_size"`` set, or
            ``None`` if the current peak already fits or no legal candidate exists.

        Raises:
            RuntimeError: if no legal candidate's region peak is below ``max_memory``.
        """
        # Stop chunking once the current peak already satisfies the memory limit.
        if max(mem_peak) < self.max_memory:
            return None

        self._remove_illegal_regions(possible_chunk_regions, chunk_infos)
        if not possible_chunk_regions:
            return None

        fitting = []
        for region in possible_chunk_regions:
            record = self._evaluate_region(region, chunk_infos)
            if record["chunk_max_mem"] < self.max_memory:
                record["chunk_len"] = self._get_compute_node_num(region["region"][0], region["region"][1])
                fitting.append(record)
        if not fitting:
            raise RuntimeError("Search failed. Try a larger memory threshold.")

        # Prefer the region with the fewest compute nodes; min() keeps the first on ties.
        best_region = min(fitting, key=lambda record: record["chunk_len"])
        return self._get_fit_chunk_size(best_region, chunk_infos)

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        """Find a chunk size for the chosen region and store it on its chunk info.

        Doubles the chunk size until the region's estimated peak reaches
        ``max_memory``, then refines between the last two sizes.

        Args:
            chunk_region_dict (dict): record produced for the chosen candidate.
            chunk_infos (list[dict]): already-selected chunk infos.

        Returns:
            dict: ``chunk_region_dict["chunk_info"]`` with ``"chunk_size"`` set
            (mutated in place). ``reorder_chunk_info["chunk_size"]`` is also
            overwritten as a side effect of the search.
        """
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        node_list = chunk_region_dict["reorder_node_list"]
        chunk_size = 1
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # NOTE(review): this assumes the estimate eventually reaches max_memory as
        # the chunk size grows; if it never does, this loop does not terminate.
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_max_mem = self._get_region_max_mem(node_list, chunk_infos, reorder_chunk_info)

        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(
            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
        )
        return chunk_info

    def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
        """Search ``[left, right]`` for a large chunk size that stays under ``max_memory``.

        ``right`` is expected to exceed the limit. ``left`` is the fallback
        returned when no probed size fits. When ``left >= 16`` the search moves
        in steps of 4, so the result approximates the true maximum.

        Returns:
            int: the largest probed size whose region peak is below
            ``max_memory``, or ``left`` if none was.
        """
        gap = 4 if left >= 16 else 1
        chunk_info = chunk_region_dict["reorder_chunk_info"]
        node_list = chunk_region_dict["reorder_node_list"]
        best = left
        while right >= left + gap:
            mid = int((left + right) / 2 + 0.5)
            chunk_info["chunk_size"] = mid
            if self._get_region_max_mem(node_list, chunk_infos, chunk_info) >= self.max_memory:
                right = mid - gap
            else:
                # Record the size verified to fit: ``left`` is about to jump to
                # ``mid + gap``, which has not been estimated and may exceed the limit.
                best = mid
                left = mid + gap
        return best

    def _get_compute_node_num(self, start, end):
        """Count the nodes in ``[start, end]`` (inclusive) not flagged by ``is_non_compute_node``."""
        return sum(
            1 for node in self.node_mgr.get_node_slice_by_idx(start, end + 1) if not is_non_compute_node(node)
        )

    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
        """Pick the legal candidate with the lowest estimated peak memory.

        Every candidate's peak is measured over the same window: from the
        smallest candidate start to the largest candidate end.

        Returns:
            dict | None: the chosen candidate with ``"chunk_size"`` set to 1, or
            ``None`` when no legal candidate exists.
        """
        self._remove_illegal_regions(possible_chunk_regions, chunk_infos)
        if not possible_chunk_regions:
            return None

        # NOTE(review): the window is built from the un-reordered candidate
        # regions but applied to each reordered node list's estimates — confirm
        # this is intended.
        window = (
            min(region["region"][0] for region in possible_chunk_regions),
            max(region["region"][1] for region in possible_chunk_regions),
        )
        records = [self._evaluate_region(region, chunk_infos, window) for region in possible_chunk_regions]

        # min() keeps the first candidate on ties.
        best_region = min(records, key=lambda record: record["chunk_max_mem"])["chunk_info"]
        best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        """Return whether ``cur_chunk_info`` may be added to ``chunk_infos``.

        A candidate is illegal if it is already selected, if its region is
        inverted (end < start), or if its inclusive ``[start, end]`` interval
        overlaps the region of any selected chunk.
        """
        chunk_region_start, chunk_region_end = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos:
            return False
        if chunk_region_end < chunk_region_start:
            return False
        # With start <= end, "entirely after" reduces to start > other_end and
        # "entirely before" to end < other_start.
        return all(
            chunk_region_start > other["region"][1] or chunk_region_end < other["region"][0]
            for other in chunk_infos
        )