from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import torch
import torch.distributed as dist

from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh


@dataclass
class DeviceMeshInfo:
    """
    This class is used to store the information used to initialize the device mesh.

    Args:
        physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
        mesh_shape (Union[torch.Size, List[int], Tuple[int, ...]], optional): The shape of the logical mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. Defaults to None, in which case no shape validation is performed here.

    Raises:
        AssertionError: If ``mesh_shape`` is given and its number of elements is not equal to ``len(physical_ids)``.
    """

    physical_ids: List[int]
    mesh_shape: Union[torch.Size, List[int], Tuple[int, ...], None] = None

    def __post_init__(self):
        # Nothing to validate when the shape is left unspecified.
        if self.mesh_shape is None:
            return

        world_size = len(self.physical_ids)
        mesh_shape_numel = torch.Size(self.mesh_shape).numel()
        # Raised explicitly instead of via `assert` so the check still runs under
        # `python -O`; AssertionError is kept for callers that already catch it.
        if world_size != mesh_shape_numel:
            raise AssertionError(
                f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
            )


def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
    """
    Build a device mesh from the given device mesh information.

    If ``device_mesh_info.mesh_shape`` is set, the physical ids are reshaped into
    that shape to form the logical mesh. Otherwise an ``AlphaBetaProfiler`` is
    asked to search for the best logical mesh.

    Args:
        device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.

    Returns:
        DeviceMesh: the device mesh, constructed with ``init_process_group=True``.
    """
    device_ids = device_mesh_info.physical_ids
    physical_mesh_id = torch.tensor(device_ids)
    requested_shape = device_mesh_info.mesh_shape

    if requested_shape is not None:
        # the user fixed the layout: lay the physical ids out in that shape
        logical_mesh_id = physical_mesh_id.reshape(requested_shape)
    else:
        # no layout given: let the profiler search for the best logical mesh
        profiler = AlphaBetaProfiler(device_ids)
        searched_mesh = profiler.search_best_logical_mesh()
        # NOTE(review): this branch casts to torch.int while the reshape branch keeps
        # the dtype inferred by torch.tensor -- confirm DeviceMesh accepts both.
        logical_mesh_id = torch.Tensor(searched_mesh).to(torch.int)

    return DeviceMesh(
        physical_mesh_id=physical_mesh_id,
        logical_mesh_id=logical_mesh_id,
        init_process_group=True,
    )


class DeviceMeshManager:
    """
    Registry that creates device meshes, keeps them under a name and tears them down.
    """

    def __init__(self):
        # maps a user-chosen name to the device mesh created under that name
        self.device_mesh_store: Dict[str, DeviceMesh] = dict()

    @staticmethod
    def _destroy_process_groups(device_mesh):
        """
        Call ``dist.destroy_process_group`` on every item found in the values of
        ``device_mesh.process_groups_dict``.

        Args:
            device_mesh (DeviceMesh): the device mesh whose process groups are destroyed
        """
        # NOTE(review): assumes each value of process_groups_dict iterates over
        # process group handles -- confirm against DeviceMesh.
        for group_collection in device_mesh.process_groups_dict.values():
            for group in group_collection:
                dist.destroy_process_group(group)

    def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMesh:
        """
        Initialize a new device mesh and register it under the given name.

        Args:
            name (str): name of the device mesh
            device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh

        Returns:
            DeviceMesh: the newly created device mesh

        Raises:
            ValueError: if a device mesh with this name is already stored
        """
        if name in self.device_mesh_store:
            raise ValueError(f"Device mesh {name} already exists.")

        new_mesh = initialize_device_mesh(device_mesh_info)
        self.device_mesh_store[name] = new_mesh
        return new_mesh

    def get(self, name: str) -> DeviceMesh:
        """
        Look up a stored device mesh.

        Args:
            name (str): name of the device mesh

        Returns:
            DeviceMesh: the device mesh

        Raises:
            ValueError: if no device mesh is stored under this name
        """
        if name not in self.device_mesh_store:
            raise ValueError(f"Device mesh {name} does not exist.")
        return self.device_mesh_store[name]

    def destroy(self, name: str) -> None:
        """
        Destroy the process groups of a stored device mesh and remove it from the manager.

        Args:
            name (str): name of the device mesh

        Raises:
            ValueError: if no device mesh is stored under this name
        """
        if name not in self.device_mesh_store:
            raise ValueError(f"Device mesh {name} does not exist.")

        self._destroy_process_groups(self.device_mesh_store[name])
        del self.device_mesh_store[name]

    def destroy_all(self):
        """
        Destroy the process groups of every stored device mesh, then empty the store.
        """
        for stored_mesh in self.device_mesh_store.values():
            self._destroy_process_groups(stored_mesh)

        self.device_mesh_store.clear()