import torch
import pytest
import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern


def check_cross_entropy():
    """Compare ``F.cross_entropy`` on a plain tensor and on a sharded ColoTensor.

    Two leaf tensors holding the same values are built. One is passed to
    ``F.cross_entropy`` directly; the other is wrapped in a ``ColoTensor``,
    redistributed with ``ShardSpec([-1], [tp_world_size])`` and given a
    ``ComputePattern.TP1D`` compute spec before the same call. The two losses,
    and the gradients accumulated on the two leaves, are asserted to be close.

    Reads ``torch.distributed.get_world_size()``, so it is expected to run
    after the distributed environment has been launched.
    """
    device = get_current_device()

    # Draw both leaves from the RNG first (reference, then the ColoTensor's
    # backing leaf), and only afterwards overwrite the second with the first.
    logits_ref = torch.randn(4, 4, device=device, requires_grad=True)
    logits_leaf = torch.randn(4, 4, device=device, requires_grad=True)
    with torch.no_grad():
        logits_leaf.copy_(logits_ref)

    labels = torch.randint(4, (4,), dtype=torch.int64, device=device)

    # Process group whose tp_degree equals the full world size.
    tp_group = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    colo_logits = ColoTensor.from_torch_tensor(tensor=logits_leaf, spec=ColoTensorSpec(tp_group))
    sharded_logits = colo_logits.redistribute(ShardSpec([-1], [tp_group.tp_world_size()]))
    sharded_logits.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

    # Forward: the loss from the sharded input must match the reference loss.
    loss_ref = F.cross_entropy(logits_ref, labels)
    loss_colo = F.cross_entropy(sharded_logits, labels)
    assert torch.allclose(loss_colo, loss_ref)

    # Backward: gradients reaching the two leaves must match as well.
    loss_ref.backward()
    loss_colo.backward()

    assert torch.allclose(logits_ref.grad, logits_leaf.grad)


def run_dist(rank, world_size, port):
    """Launch colossalai for this process, then run the cross-entropy check.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: World size forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch`` (host is ``'localhost'``).
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_kwargs)
    check_cross_entropy()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist`` on one free port."""
    # mp.spawn invokes fn(process_index, *args), so every worker is called as
    # run_dist(index, world_size, port) with the same port for all of them.
    mp.spawn(run_dist, args=(world_size, free_port()), nprocs=world_size)


# When executed as a script (rather than collected by pytest), run the test
# once with world_size=1.
if __name__ == '__main__':
    test_loss_func(1)