from functools import partial

import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

import colossalai
from colossalai.nn.parallel.reducer import Reducer
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device

REDUCE_CNT = 0


def check_eq(grad, grad_clone):
    global REDUCE_CNT
    print(f'Rank{dist.get_rank()} check {REDUCE_CNT}')
    REDUCE_CNT += 1
    assert torch.allclose(grad, grad_clone)


def run_reducer():
    grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
    grads_clone = [g.clone().detach() for g in grads]
    for g in grads:
        dist.all_reduce(g)
    reducer = Reducer(bucket_size_mb=1)
    for g, g_clone in zip(grads, grads_clone):
        reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
    reducer.flush()


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_reducer()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_reducer(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_reducer(2)