import math
import time
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger

GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0


class AlphaBetaProfiler:
    '''
    Profile alpha and beta value for a given device list.

    Usage:
        # Note: the environment of execution is supposed to be
        # multi-process with multi-gpu in mpi style.
        >>> physical_devices = [0, 1, 4, 5]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> ab_dict = profiler.alpha_beta_dict
        >>> print(ab_dict)
        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
    '''

    def __init__(self,
                 physical_devices: List[int],
                 alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
                 ctype: str = 'a',
                 warmup: int = 5,
                 repeat: int = 25,
                 latency_iters: int = 5,
                 homogeneous_tolerance: float = 0.1):
        '''
        Args:
            physical_devices: A list of device id, each element inside it is the global rank of that device.
            alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
            ctype: 'a' for all-reduce, 'b' for broadcast.
            warmup: Number of warmup iterations.
            repeat: Number of iterations to measure.
            latency_iters: Number of iterations to measure latency.
        '''
        self.physical_devices = physical_devices
        self.ctype = ctype
        self.world_size = len(physical_devices)
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.homogeneous_tolerance = homogeneous_tolerance
        self.process_group_dict = None
        self._init_profiling()
        if alpha_beta_dict is None:
            self.alpha_beta_dict = self.profile_ab()
        else:
            self.alpha_beta_dict = alpha_beta_dict

    def _init_profiling(self):
        # Create process group list based on its global rank
        process_group_list = []
        for f_index in range(self.world_size - 1):
            for b_index in range(f_index + 1, self.world_size):
                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))

        # Create process group dict which maps process group to its handler
        process_group_dict = {}
        for process_group in process_group_list:
            pg_handler = dist.new_group(process_group)
            process_group_dict[process_group] = pg_handler

        self.process_group_dict = process_group_dict

    def _profile(self, process_group, pg_handler, nbytes):
        logger = get_dist_logger()
        rank = dist.get_rank()
        src_device_num = process_group[0]
        world_size = len(process_group)

        device = torch.cuda.current_device()
        buf = torch.randn(nbytes // 4).to(device)

        torch.cuda.synchronize()
        # warmup
        for _ in range(self.warmup):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()

        dist.barrier(group=pg_handler)
        begin = time.perf_counter()
        for _ in range(self.repeat):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()
        end = time.perf_counter()
        dist.barrier(group=pg_handler)

        if rank == src_device_num:
            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
            alg_band = nbytes / avg_time_s
            if self.ctype == "a":
                # convert the bandwidth of all-reduce algorithm to the bandwidth of the hardware.
                bus_band = 2 * (world_size - 1) / world_size * alg_band
                bus_band = alg_band
            elif self.ctype == "b":
                bus_band = alg_band

            logger.info(
                f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
            )
            return (avg_time_s, alg_band)
        else:
            # Just a placeholder
            return (None, None)

    def profile_latency(self, process_group, pg_handler):
        '''
        This function is used to profile the latency of the given process group with a series of bytes.

        Args:
            process_group: A tuple of global rank of the process group.
            pg_handler: The handler of the process group.

        Returns:
            latency: None if the latency is not measured, otherwise the median of the latency_list.
        '''
        latency_list = []
        for i in range(self.latency_iters):
            nbytes = int(BYTE << i)
            (t, _) = self._profile(process_group, pg_handler, nbytes)
            latency_list.append(t)

        if latency_list[0] is None:
            latency = None
        else:
            median_index = math.floor(self.latency_iters / 2)
            latency = latency_list[median_index]

        return latency

    def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
        '''
        This function is used to profile the bandwidth of the given process group.

        Args:
            process_group: A tuple of global rank of the process group.
            pg_handler: The handler of the process group.
        '''
        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
        return bandwidth

    def profile_ab(self):
        '''
        This method is used to profiling the alpha and beta value for a given device list.

        Returns:
            alpha_beta_dict: A dict which maps process group to its alpha and beta value.
        '''
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()
        global_pg_handler = dist.new_group(self.physical_devices)

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            assert rank in process_group
            device = torch.cuda.current_device()
            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
            return max_nbytes

        for process_group, pg_handler in self.process_group_dict.items():
            if rank not in process_group:
                max_nbytes = None
                alpha = None
                bandwidth = None
            else:
                max_nbytes = get_max_nbytes(process_group, pg_handler)
                alpha = self.profile_latency(process_group, pg_handler)
                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)

            if bandwidth is None:
                beta = None
            else:
                beta = 1 / bandwidth

            broadcast_list = [alpha, beta]
            dist.broadcast_object_list(broadcast_list, src=process_group[0])
            alpha_beta_dict[process_group] = tuple(broadcast_list)

        # add symmetry pair to the alpha_beta_dict
        symmetry_ab_dict = {}
        for process_group, alpha_beta_pair in alpha_beta_dict.items():
            symmetry_process_group = (process_group[1], process_group[0])
            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair

        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict

    def search_best_logical_mesh(self):
        '''
        This method is used to search the best logical mesh for the given device list.

        The best logical mesh is searched in following steps:
            1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict
                are homogeneous if the beta value is close enough.
            2. Find the best homogeneous device group contains all the physical devices. The best homogeneous
                device group means the lowest beta value in the groups which contains all the physical devices.
                And the reason we require the group contains all the physical devices is that the devices not in
                the group will decrease the bandwidth of the group.
            3. If the best homogeneous device group is found, we will construct the largest ring for each device
                based on the best homogeneous device group, and the best logical mesh will be the union of all the
                rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for
                4 devices.

        Returns:
            best_logical_mesh: The best logical mesh for the given device list.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> best_logical_mesh = profiler.search_best_logical_mesh()
            >>> print(best_logical_mesh)
            [[0, 1], [2, 3]]
        '''

        def _power_of_two(integer):
            return integer & (integer - 1) == 0

        def _detect_homogeneous_device(alpha_beta_dict):
            '''
            This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.

            Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
                of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
                * base_beta.
            '''
            homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
            for process_group, (_, beta) in alpha_beta_dict.items():
                if homogeneous_device_dict is None:
                    homogeneous_device_dict[beta] = []
                    homogeneous_device_dict[beta].append(process_group)

                match_beta = None
                for beta_value in homogeneous_device_dict.keys():
                    if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
                            1 - self.homogeneous_tolerance):
                        match_beta = beta_value
                        break

                if match_beta is not None:
                    homogeneous_device_dict[match_beta].append(process_group)
                else:
                    homogeneous_device_dict[beta] = []
                    homogeneous_device_dict[beta].append(process_group)

            return homogeneous_device_dict

        def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
            '''
            This function is used to check whether the homogeneous_group contains all physical devices.
            '''
            flatten_mesh = []
            for process_group in homogeneous_group:
                flatten_mesh.extend(process_group)
            non_duplicated_flatten_mesh = set(flatten_mesh)
            return len(non_duplicated_flatten_mesh) == len(self.physical_devices)

        def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
            '''
            This function is used to construct the largest ring in the homogeneous_group for each rank.
            '''
            # Construct the ring
            ring = []
            ranks_in_ring = []
            for rank in self.physical_devices:
                if rank in ranks_in_ring:
                    continue
                stable_status = False
                ring_for_rank = []
                ring_for_rank.append(rank)
                check_rank_list = [rank]
                rank_to_check_list = []

                while not stable_status:
                    stable_status = True
                    check_rank_list.extend(rank_to_check_list)
                    rank_to_check_list = []
                    for i in range(len(check_rank_list)):
                        check_rank = check_rank_list.pop()
                        for process_group in homogeneous_group:
                            if check_rank in process_group:
                                rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
                                if rank_to_append not in ring_for_rank:
                                    stable_status = False
                                    rank_to_check_list.append(rank_to_append)
                                    ring_for_rank.append(rank_to_append)

                ring.append(ring_for_rank)
                ranks_in_ring.extend(ring_for_rank)

            return ring

        assert _power_of_two(self.world_size)
        power_of_two = int(math.log2(self.world_size))
        median = power_of_two // 2
        balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
        row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
        balanced_logical_mesh = []
        for row_index in range(row_size):
            balanced_logical_mesh.append([])
            for column_index in range(column_size):
                balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])

        homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
        beta_list = [b for b in homogeneous_device_dict.keys()]
        beta_list.sort()
        beta_list.reverse()
        homogeneous_types = len(beta_list)
        best_logical_mesh = None
        if homogeneous_types >= 2:
            for _ in range(homogeneous_types - 1):
                lowest_beta = beta_list.pop()
                best_homogeneous_group = homogeneous_device_dict[lowest_beta]
                # if the best homogeneous group contains all physical devices,
                # we will build the logical device mesh based on it. Otherwise,
                # we will check next level homogeneous group.
                if _check_contain_all_devices(best_homogeneous_group):
                    # We choose the largest ring for each rank to maximum the best bus utilization.
                    best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
                    break

        if homogeneous_types == 1 or best_logical_mesh is None:
            # in this case, we use balanced logical mesh as the best
            # logical mesh.
            best_logical_mesh = balanced_logical_mesh

        return best_logical_mesh

    def extract_alpha_beta_for_device_mesh(self):
        '''
        Extract the mesh_alpha list and mesh_beta list based on the
            best logical mesh, which will be used to initialize the device mesh.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
            >>> print(mesh_alpha)
            [2.5917552411556242e-05, 0.00010312341153621673]
            >>> print(mesh_beta)
            [5.875573704655635e-11, 4.7361584445959614e-12]
        '''
        best_logical_mesh = self.search_best_logical_mesh()

        first_axis = [row[0] for row in best_logical_mesh]
        second_axis = best_logical_mesh[0]

        # init process group for both axes
        first_axis_process_group = dist.new_group(first_axis)
        second_axis_process_group = dist.new_group(second_axis)

        # extract alpha and beta for both axes
        def _extract_alpha_beta(pg, pg_handler):
            latency = self.profile_latency(pg, pg_handler)
            bandwidth = self.profile_bandwidth(pg, pg_handler)
            broadcast_object = [latency, bandwidth]
            dist.broadcast_object_list(broadcast_object, src=pg[0])
            return broadcast_object

        first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
        second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
        mesh_alpha = [first_latency, second_latency]
        # The beta values have been enlarged by 1e10 times temporarily because the computation cost
        # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
        mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]

        return mesh_alpha, mesh_beta