from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_apply(rank, world_size, port):
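    """Check that ShapeConsistencyManager.apply converts a tensor sharded as
    [S0, R] on a 2 x 2 device mesh into the [R, S0] layout, and that every
    rank ends up with the expected local shard.
    """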
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
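    # init_process_group=True has the DeviceMesh build the process groups
    # that the layout-conversion collectives will run on.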
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    entire_shape = torch.Size((4, 2))
    shape_consistency_manager = ShapeConsistencyManager()
    dim_partition_source = {0: [0]}
    dim_partition_target = {1: [0]}
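    # dim_partition maps a tensor dim to the mesh dims it is sharded over:
    # {0: [0]} shards tensor dim 0 across mesh dim 0 ([S0, R]);
    # {1: [0]} shards tensor dim 1 across mesh dim 0 ([R, S0]).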

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

    # DistSpec:
    #     shard_sequence: R,S0
    #     device_mesh_shape: (2, 2)
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)

    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
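
    # Taken together, the shards above assemble into the global tensor
    #     tensor([[0., 1.],
    #             [0., 1.],
    #             [2., 3.],
    #             [2., 3.]])
    # with mesh row 0 (ranks 0, 1) holding rows 0-1 and mesh row 1
    # (ranks 2, 3) holding rows 2-3, i.e. the [S0, R] layout.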

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
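
    # Sanity sketch (illustrative, local math only, no communication): under
    # the target [R, S0] layout, tensor dim 1 is split along mesh dim 0, so
    # mesh row 0 (ranks 0, 1) should own column 0 of the global tensor and
    # mesh row 1 (ranks 2, 3) column 1. The names below are ad hoc to this
    # cross-check.
    global_tensor = torch.tensor([[0., 1.], [0., 1.], [2., 3.], [2., 3.]]).cuda()
    expected_cols = torch.chunk(global_tensor, 2, dim=1)
    assert tensor_to_check.equal(expected_cols[rank // 2])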

    tensor_to_comm.sharding_spec = sharding_spec_source
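    # apply() compares the tensor's attached source spec with the target spec
    # and issues the collective(s) required for the conversion; for
    # [S0, R] -> [R, S0] this is typically an all-to-all along mesh dim 0,
    # though the exact communication path is chosen by the manager.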
    tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    assert tensor_to_comm.equal(tensor_to_check)
    assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
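    # mp.spawn passes the process rank as the first positional argument,
    # filling check_apply's remaining `rank` parameter.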
    run_func = partial(check_apply, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_apply()