from copy import deepcopy
from typing import Dict, List

from ..utils import merge_same_dim_mesh_list
from .misc import ShardingOutOfIndexError

# Names exported by a star-import of this module.
# NOTE(review): 'ShardingException' is neither defined nor imported in the visible part of
# this file (only ShardingOutOfIndexError is imported) -- confirm it exists, otherwise a
# star-import of this module fails on that name.
__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec']

# Cost weights used by DimSpec.build_difference_2d_dict to score the conversion of one
# dim spec into another.
ALLGATHER_COST = 20    # charged for each all-gather step of a conversion
SHARD_COST = 5    # charged for each shard step of a conversion
STEP_PENALTY = 6    # added between consecutive steps of a multi-step conversion
NAN = 'nan'    # placeholder difference for a spec pair that no conversion rule matches


class DimSpec:
    '''
    Sharding spec for a single dimension of a sharded tensor. It records which axes of the
    logical device mesh the dimension is sharded along and gives a method to compute the
    difference (conversion cost) between two such specs.
    This class is used internally in ShardingSpec.

    Argument:
        shard_list(List[int]): if shard_list is None or empty, the dim spec will be 'R' type.
            Otherwise, the element in shard_list means the data will be sharded in that dimension.
    '''

    def __init__(self, shard_list):
        # The documented contract accepts None for an 'R' spec; normalise it to an empty
        # list so that len() here and the iteration in __repr__ work.
        if shard_list is None:
            shard_list = []
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list
        self.build_difference_2d_dict()

    def __eq__(self, other):
        # Equality is defined on the string form, so a DimSpec also compares equal to any
        # object whose str() matches (e.g. the plain string 'S0').
        return str(self) == str(other)

    def __repr__(self):
        # 'R' for a replica spec, otherwise 'S' followed by the mesh axes, e.g. 'S01'.
        if self.is_replica:
            return 'R'
        return 'S' + ''.join(str(dim) for dim in self.shard_list)

    def _convert_str_to_shard_list(self, str_spec):
        '''
        Convert str_spec into shard_list.

        Argument:
            str_spec(str): dim spec in str type, one of 'R', 'S0', 'S1' or 'S01'.

        Return:
            shard_list(List[int]): a new list holding the mesh axes named by str_spec,
                or None if str_spec is not one of the specs listed above.
        '''
        # The mapping is rebuilt on every call, so callers always receive a fresh list.
        known_specs = {'R': [], 'S0': [0], 'S1': [1], 'S01': [0, 1]}
        return known_specs.get(str_spec)

    def build_difference_2d_dict(self):
        '''
        Build a difference mapping for 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.

        The mapping is stored in self.difference_dict. Its keys are
        (source_spec, target_spec) string pairs such as ('S0', 'S01') and its values are
        the cost of converting the source spec into the target spec.
        '''
        spec_list = ['R', 'S0', 'S1', 'S01']
        difference_dict = {}
        for source_spec in spec_list:
            source_shard_list = self._convert_str_to_shard_list(source_spec)
            source_len = len(source_shard_list)
            for target_spec in spec_list:
                target_shard_list = self._convert_str_to_shard_list(target_spec)
                target_len = len(target_shard_list)

                # The order of the branches matters: the single-step conversions are tried
                # first, the remaining branches only look at the length difference.

                # source same as target
                if source_shard_list == target_shard_list:
                    difference = 0

                # all_gather(source) -> target
                elif source_len == target_len + 1 and source_shard_list[:-1] == target_shard_list:
                    difference = ALLGATHER_COST

                # shard(source) -> target
                elif (source_len == target_len - 1 and source_shard_list == target_shard_list[:-1]
                      and target_shard_list[-1] not in source_shard_list):
                    difference = SHARD_COST

                # S1 -> S0 or S0 -> S1
                elif source_len == target_len:
                    # source -> R -> target
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST

                # R -> S01
                elif source_len == target_len - 2:
                    difference = SHARD_COST + STEP_PENALTY + SHARD_COST

                # S01 -> R
                elif source_len == target_len + 2:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST

                # S1 -> S01
                elif source_len == target_len - 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST

                # S01 -> S1
                elif source_len == target_len + 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST

                else:
                    difference = NAN

                # The spec names are immutable strings, so they can be used as keys directly.
                difference_dict[(source_spec, target_spec)] = difference

        self.difference_dict = difference_dict

    def dim_diff(self, other):
        '''
        The difference between two DimSpec.

        Argument:
            other(DimSpec): the dim spec to compare with.

        Return:
            difference(int): the difference between two DimSpec.

        Raise:
            KeyError: if either spec is not part of the 2D device mesh table
                ('R', 'S0', 'S1', 'S01').

        Example:
            dim_spec = DimSpec([0])
            other_dim_spec = DimSpec([0, 1])
            print(dim_spec.dim_diff(other_dim_spec))

        Output:
            5
        '''
        return self.difference_dict[(str(self), str(other))]


class ShardingSpec:
    '''
    Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like
    [R, R, S0, S1], which means the first two tensor dimensions are 'R' type (replica), the third one is
    sharded along axis 0 of the logical device mesh and the fourth one along axis 1.

    At least one of dim_partition_dict and sharding_sequence has to be given; the missing one is derived
    from the other.

    Argument:
        dim_size(int): The number of dimensions of the tensor.
        dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
            and the value of the key describe which logical axis will be sharded in that dimension.
        sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
    '''

    def __init__(self,
                 dim_size: int,
                 dim_partition_dict: Dict[int, List[int]] = None,
                 sharding_sequence: List[DimSpec] = None):
        self.dims = dim_size
        self.dim_partition_dict = dim_partition_dict
        self.sharding_sequence = sharding_sequence
        if self.sharding_sequence is None:
            assert self.dim_partition_dict is not None, 'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims,
                                                               dim_partition_dict=self.dim_partition_dict)
            self.sharding_sequence = self.convert_dict_to_shard_sequence()

        elif self.dim_partition_dict is None:
            # sharding_sequence is known to be set in this branch, so derive the dict from it.
            self.dim_partition_dict = self.convert_shard_sequence_to_dict()

        self._sanity_check()

    def _sanity_check(self):
        '''
        Check that neither the sharding sequence nor the dim partition dict refers to more
        dimensions than the tensor has.

        Raise:
            ShardingOutOfIndexError: if sharding_sequence has more than self.dims elements, or
                if the largest key of dim_partition_dict is not less than self.dims.
        '''
        if len(self.sharding_sequence) > self.dims:
            raise ShardingOutOfIndexError(
                f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')

        # An empty dict has no key to check (and max() would fail on it).
        if self.dim_partition_dict:
            max_dim = max(self.dim_partition_dict)
            if max_dim >= self.dims:
                raise ShardingOutOfIndexError(
                    f'the key of dim_partition_dict should be less than {self.dims}, but got {max_dim}.')

    def __repr__(self):
        shard_sequence = ",".join(str(dimspec) for dimspec in self.sharding_sequence)
        return ' '.join(["ShardingSpec:", "\n\tshard_sequence: " + shard_sequence])

    def convert_dict_to_shard_sequence(self):
        '''
        Convert dim_partition_dict into list of DimSpec.

        Return:
            sharding_sequence(List[DimSpec]): one DimSpec per tensor dimension; dimensions
                that are not a key of dim_partition_dict get an 'R' spec.
        '''
        # Create a separate DimSpec for every dimension. `[DimSpec([])] * self.dims` would
        # store the very same object (and shard_list) in every slot, so changing one
        # replica entry in place would silently change all of them.
        sharding_sequence = [DimSpec([]) for _ in range(self.dims)]
        # NOTE(review): a key >= self.dims raises IndexError here, before _sanity_check gets
        # to report it as ShardingOutOfIndexError -- confirm whether that is intended.
        for dim, shard_list in self.dim_partition_dict.items():
            sharding_sequence[dim] = DimSpec(shard_list)
        return sharding_sequence

    def convert_shard_sequence_to_dict(self):
        '''
        Convert sharding_sequence into dim_partition_dict.

        Return:
            dim_partition_dict(Dict[int, List[int]]): maps the index of every non-replica
                dimension to a copy of its shard_list.
        '''
        return {
            index: list(dim_spec.shard_list)
            for index, dim_spec in enumerate(self.sharding_sequence)
            if not dim_spec.is_replica
        }

    def spec_diff(self, other):
        '''
        This function is a naive version of difference computation. It just simply accumulates difference every dimension between the
        pair of sharding sequence.

        Example:
            # ShardingSpec:
            #     shard_sequence: S01,R,R
            sharding_spec = ShardingSpec(3, dim_partition_dict={0: [0, 1]})
            # ShardingSpec:
            #     shard_sequence: S0,S1,R
            sharding_spec_to_compare = ShardingSpec(3, dim_partition_dict={0: [0], 1: [1]})
            print(sharding_spec.spec_diff(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compared with.

        Return:
            difference(int): Difference between two ShardingSpec.
        '''
        assert len(self.sharding_sequence) == len(
            other.sharding_sequence), 'Cannot compare difference for two sharding specs with different length.'
        return sum(
            orig_dim_spec.dim_diff(other_dim_spec)
            for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence))