import operator
from functools import reduce

import torch

from colossalai.device.device_mesh import DeviceMesh

from .utils import merge_same_dim_mesh_list

# Public API of this module. Every entry must name an object defined here, otherwise
# `from <module> import *` fails with AttributeError.
__all__ = ["_DimSpec", "ShardingSpecException", "ShardingSpec"]

# Relative costs used by _DimSpec._build_difference_2d_dict to score converting one dim spec into another.
ALLGATHER_COST = 20  # cost of one all-gather step (e.g. S0 -> R)
SHARD_COST = 5  # cost of one shard step (e.g. R -> S0)
STEP_PENALTY = 6  # extra cost added between consecutive steps of a multi-step conversion
# Sentinel for spec pairs the cost table does not cover. NOTE: it is a string, not a float NaN.
NAN = "nan"


class _DimSpec:
    """
    Sharding spec for a single dimension of a sharded tensor.

    It records which logical device-mesh axes the dimension is sharded over and provides a cost
    estimate for converting one dim spec into another. This class is used internally in ShardingSpec.

    Argument:
        shard_list(List[int]): logical mesh axes the dimension is sharded over. An empty list makes
            the dim spec 'R' (replicated); otherwise the spec is rendered as 'S' followed by the axes,
            e.g. [0, 1] -> 'S01'.

    Note:
        Equality compares the string forms (see __repr__). Because __eq__ is defined without __hash__,
        instances are unhashable.
    """

    # Class-wide lazily built cost table shared by all instances (see difference_dict).
    _DIFFERENCE_DICT = None

    def __init__(self, shard_list):
        # Computed once at construction; later mutation of shard_list does not update it.
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list

    def __eq__(self, other):
        return str(self) == str(other)

    def __repr__(self):
        if self.is_replica:
            return "R"
        target = "S"
        for dim in self.shard_list:
            target += str(dim)
        return target

    @property
    def difference_dict(self):
        """
        Returns the difference dict, building it on first use.

        The table is stored on the class, not on the instance, so it is built once and then shared
        by every _DimSpec instead of being rebuilt for each new object.

        Return:
            difference_dict(Dict[Tuple[str, str], Union[int, str]]):
                maps (source_spec_str, target_spec_str) to the conversion cost.
        """
        cls = type(self)
        if cls._DIFFERENCE_DICT is None:
            cls._DIFFERENCE_DICT = cls._build_difference_2d_dict()

        return cls._DIFFERENCE_DICT

    def difference(self, other):
        """
        The difference (conversion cost) from this _DimSpec to another.

        Argument:
            other(_DimSpec): the dim spec to compare with.

        Return:
            difference(int): the difference between two _DimSpec.

        Raises:
            KeyError: if either spec is not one of 'R', 'S0', 'S1', 'S01' (the only specs in the
                2D difference table).

        Example:
            dim_spec = _DimSpec([0])
            other_dim_spec = _DimSpec([0, 1])
            print(dim_spec.difference(other_dim_spec))

        Output:
            5
        """
        difference = self.difference_dict[(str(self), str(other))]
        return difference

    @classmethod
    def _build_difference_2d_dict(cls):
        """
        Build a difference mapping for the 2D device mesh case. It is used to
        compute the difference between _DimSpec pairs.

        Return:
            Dict[Tuple[str, str], Union[int, str]]: cost for every (source, target) pair drawn from
                'R', 'S0', 'S1', 'S01'.
        """

        source_spec_list = ["R", "S0", "S1", "S01"]
        target_spec_list = ["R", "S0", "S1", "S01"]
        difference_dict = {}
        for source_spec in source_spec_list:
            for target_spec in target_spec_list:
                source_shard_list = cls._convert_str_to_shard_list(source_spec)
                target_shard_list = cls._convert_str_to_shard_list(target_spec)

                # source same as target
                if source_shard_list == target_shard_list:
                    difference = 0

                # all_gather(source) -> target: drop the last sharded axis
                elif (
                    len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list
                ):
                    difference = ALLGATHER_COST

                # shard(source) -> target: append one new, unused axis
                elif (
                    len(source_shard_list) == len(target_shard_list) - 1
                    and source_shard_list == target_shard_list[:-1]
                    and target_shard_list[-1] not in source_shard_list
                ):
                    difference = SHARD_COST

                # S1 -> S0 or S0 -> S1
                elif len(source_shard_list) == len(target_shard_list):
                    # source -> R -> target
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST

                # R -> S01
                elif len(source_shard_list) == len(target_shard_list) - 2:
                    difference = SHARD_COST + STEP_PENALTY + SHARD_COST

                # S01 -> R
                elif len(source_shard_list) == len(target_shard_list) + 2:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST

                # S1 -> S01 (prefix mismatch: source -> R -> S0 -> S01)
                elif len(source_shard_list) == len(target_shard_list) - 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST

                # S01 -> S1 (prefix mismatch: source -> S0 -> R -> S1)
                elif len(source_shard_list) == len(target_shard_list) + 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST

                else:
                    difference = NAN
                difference_dict[(source_spec, target_spec)] = difference

        return difference_dict

    @staticmethod
    def _convert_str_to_shard_list(str_spec):
        """
        Convert str_spec into shard_list.

        Argument:
            str_spec(str): dim spec in str type ('R', 'S0', 'S1' or 'S01').

        Return:
            List[int] for the recognized specs; None (implicitly) for any other string.
        """

        if str_spec == "R":
            return []
        if str_spec == "S0":
            return [0]
        if str_spec == "S1":
            return [1]
        if str_spec == "S01":
            return [0, 1]


class ShardingSpecException(Exception):
    """Base class for the errors raised by ShardingSpec's sanity checks."""


class ShardingOutOfIndexError(ShardingSpecException):
    """Raised when a sharded tensor dimension lies outside the tensor's shape."""


class DuplicatedShardingDimensionError(ShardingSpecException):
    """Raised when a logical device-mesh axis is reused or is not an axis of the mesh."""


class ShardingNotDivisibleError(ShardingSpecException):
    """Raised when a tensor dimension cannot be split evenly over its sharding devices."""


class ShardingSpec:
    """
    Sharding spec for a tensor. It records the logical device mesh this tensor belongs to,
    the entire shape of the tensor before sharding, and a sharding sequence that looks like
    [R, R, S0, S1].

    Argument:
        device_mesh(DeviceMesh): A logical view of a physical mesh. Its `shape` and
            `logical_mesh_id` attributes are read.
        entire_shape(torch.Size): The entire shape of tensor before sharded. A list or tuple is
            converted to torch.Size.
        dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
            and the value lists the logical mesh axes that dimension is sharded over.
        sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].

    At least one of dim_partition_dict / sharding_sequence must be given; the missing one is
    derived from the other. If both are given, they are used as-is without a consistency check.

    Raises:
        DuplicatedShardingDimensionError, ShardingOutOfIndexError, ShardingNotDivisibleError:
            from the sanity check run at the end of construction.
    """

    def __init__(
        self, device_mesh: DeviceMesh, entire_shape: torch.Size, dim_partition_dict=None, sharding_sequence=None
    ):
        self.device_mesh = device_mesh

        if isinstance(entire_shape, (list, tuple)):
            entire_shape = torch.Size(entire_shape)
        self.entire_shape = entire_shape
        self.dim_partition_dict = dim_partition_dict
        self.sharding_sequence = sharding_sequence
        if self.sharding_sequence is None:
            assert (
                self.dim_partition_dict is not None
            ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object."
            # NOTE(review): presumably merges keys that refer to the same tensor dimension
            # (e.g. negative indices) into one entry; see .utils.merge_same_dim_mesh_list.
            self.dim_partition_dict = merge_same_dim_mesh_list(
                dim_size=len(entire_shape), dim_partition_dict=self.dim_partition_dict
            )
            self.convert_dict_to_shard_sequence()
        elif self.dim_partition_dict is None:
            # sharding_sequence is known to be non-None in this branch.
            self.convert_shard_sequence_to_dict()
        self._sanity_check()

    def __repr__(self):
        res_list = ["DistSpec:"]
        res_list.append("\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
        res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
        return " ".join(res_list)

    def _sanity_check(self):
        """
        Validate dim_partition_dict against the device mesh and entire_shape.

        Raises:
            DuplicatedShardingDimensionError: a mesh axis is used twice or is not an axis of the mesh.
            ShardingOutOfIndexError: a sharded dimension index is >= len(entire_shape).
            ShardingNotDivisibleError: a dimension size is not divisible by its device count.
        """
        # make sure all axes in logical device mesh only be used once
        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
        for dim, shard_list in self.dim_partition_dict.items():
            for element in shard_list:
                if element in dim_check_list:
                    dim_check_list.remove(element)
                else:
                    raise DuplicatedShardingDimensionError(
                        f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}."
                    )

        # make sure that the dimension is not out of index
        for dim in self.dim_partition_dict:
            if dim >= len(self.entire_shape):
                raise ShardingOutOfIndexError(
                    f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
                )

        # make sure that the sharding for a dimension is divisible by the number of devices
        for dim, shard_list in self.dim_partition_dict.items():
            tensor_dim_size = self.entire_shape[dim]
            # Total devices along this dimension = product of the sizes of the mesh axes it is sharded over.
            num_devices = reduce(operator.mul, (self.device_mesh.shape[element] for element in shard_list), 1)

            if tensor_dim_size % num_devices != 0:
                raise ShardingNotDivisibleError(
                    f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices."
                )

    def convert_dict_to_shard_sequence(self):
        """
        Convert dim_partition_dict into a list of _DimSpec and assign it to sharding_sequence.

        Dimensions absent from dim_partition_dict become replicated ('R').
        """
        # One distinct replica _DimSpec per dimension. `[_DimSpec([])] * n` would make every replicated
        # dimension share a single object (and a single shard_list), so mutating one would change all.
        sharding_sequence = [_DimSpec([]) for _ in range(len(self.entire_shape))]
        for dim, shard_list in self.dim_partition_dict.items():
            sharding_sequence[dim] = _DimSpec(shard_list)
        self.sharding_sequence = sharding_sequence

    def convert_shard_sequence_to_dict(self):
        """
        Convert sharding_sequence into dim_partition_dict.

        Only non-replicated dimensions get an entry; each value is a copy of that dim spec's shard_list.
        """
        self.dim_partition_dict = {
            index: list(dim_spec.shard_list)
            for index, dim_spec in enumerate(self.sharding_sequence)
            if not dim_spec.is_replica
        }

    def sharding_sequence_difference(self, other):
        """
        This function is a naive version of difference computation. It simply accumulates the difference
        of every dimension between the pair of sharding sequences.

        Example:
            dim_partition_dict = {0: [0, 1]}
            # DistSpec:
            #     shard_sequence: S01,R,R
            #     device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            dim_partition_dict_to_compare = {0: [0], 1: [1]}
            # DistSpec:
            #     shard_sequence: S0,S1,R
            #     device_mesh_shape: (4, 4)
            sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
            print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compare with.

        Return:
            difference(int): Difference between two ShardingSpec.
        """
        assert len(self.sharding_sequence) == len(
            other.sharding_sequence
        ), f"Cannot compare difference for two sharding specs with different length."
        difference = 0
        for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
            difference += orig_dim_spec.difference(other_dim_spec)
        return difference

    def get_sharded_shape_per_device(self):
        """
        Return the local shape held by each device.

        Every sharded dimension is divided by the product of the sizes of the mesh axes it is sharded over.

        Return:
            torch.Size: the per-device shape.
        """
        sharded_shape = list(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
            shard_partitions = reduce(operator.mul, mesh_list, 1)
            assert (
                sharded_shape[dim] % shard_partitions == 0
            ), f"Cannot shard dimension {dim} into {shard_partitions} partitions."
            sharded_shape[dim] //= shard_partitions
        return torch.Size(sharded_shape)