import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_padded_tensor(rank, world_size, port):
    """Per-rank worker exercising the padded-tensor API on a distributed tensor.

    Launched once per process by ``spawn``. It does the following:

    * builds a 2x2 device mesh over ranks 0-3;
    * shards a (32, 64) tensor along dim 0 over mesh axis 0;
    * pads the local shard along dim 0;
    * checks that ``clone``/``detach`` keep both the padded and distributed
      markers;
    * unpads and gathers back to the original global shape.

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes; the mesh below is hard-coded
            for 4 ranks.
        port: Rendezvous port on localhost.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    original_tensor = torch.rand(32, 64).to("cuda")

    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

    padding_dim = 0
    padded_length = 64
    # Snapshot the local shard's shape BEFORE padding. to_padded_tensor /
    # to_unpadded_tensor may modify and return the very same tensor object; if
    # so, comparing against ``d_tensor.shape`` afterwards would compare the
    # tensor with itself and could never fail.
    local_shape = d_tensor.shape
    # The test is only meaningful if padding actually grows the shard.
    assert local_shape[padding_dim] < padded_length

    padded_tensor = to_padded_tensor(d_tensor, current_length=padded_length, padding_dim=padding_dim)
    assert is_padded_tensor(padded_tensor)
    assert padded_tensor.shape[padding_dim] == padded_length
    assert padded_tensor.dist_layout == d_tensor.dist_layout

    # clone() and detach() must propagate both the padding and the distribution metadata.
    tensor_copy = padded_tensor.clone()
    assert is_padded_tensor(tensor_copy)
    assert is_distributed_tensor(tensor_copy)

    tensor_detached = padded_tensor.detach()
    assert is_padded_tensor(tensor_detached)
    assert is_distributed_tensor(tensor_detached)

    # Unpadding must restore the pre-padding local shard shape.
    unpadded_tensor = to_unpadded_tensor(padded_tensor)
    assert unpadded_tensor.shape == local_shape
    assert is_distributed_tensor(unpadded_tensor)

    # Gathering the unpadded shards must reproduce the original global shape.
    global_tensor = to_global(unpadded_tensor)
    assert global_tensor.shape == original_tensor.shape


@rerun_if_address_is_in_use()
def test_padded_tensor():
    """Run ``check_padded_tensor`` in 4 spawned processes.

    The process count must equal the number of ranks in the 2x2 device mesh
    that ``check_padded_tensor`` builds.
    """
    num_procs = 4
    spawn(check_padded_tensor, num_procs)


if __name__ == "__main__":
    # Allow running this distributed test directly, without pytest.
    test_padded_tensor()