import os
import random
from typing import Callable, Type

import numpy as np
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager


def set_seed(seed):
    """Seed every random number source this test touches with ``seed``.

    Covers Python's ``random``, NumPy, ``torch`` and ``torch.cuda``, exports the
    value as ``PYTHONHASHSEED`` and switches on
    ``torch.backends.cudnn.deterministic``.

    Args:
        seed: Value handed unchanged to each seeding function (stringified for
            the environment variable).
    """
    random.seed(seed)
    # Only interpreters started after this point pick up the hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    """Wrap ``module`` in ``ColoDDP`` using a default-constructed ``ProcessGroup``.

    Args:
        module: The model to wrap.

    Returns:
        The ``ColoDDP`` instance built around ``module``.
    """
    return ColoDDP(module, process_group=ProcessGroup())


def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
    """Wrap ``module`` in ``ZeroDDP`` backed by a CUDA ``GeminiManager``.

    The chunk configuration is the first item returned by
    ``search_chunk_configuration(module, 4, 1024)``; any further items are
    discarded.

    Args:
        module: The model to wrap.

    Returns:
        The ``ZeroDDP`` instance built around ``module``.
    """
    config, *_rest = search_chunk_configuration(module, 4, 1024)
    manager = GeminiManager('cuda', ChunkManager(config))
    return ZeroDDP(module, manager)


class Net(torch.nn.Module):
    """Tiny two-layer model: a bias-free 3->3 linear layer followed by a bias-free 3->1 one."""

    def __init__(self) -> None:
        super().__init__()
        make_linear = torch.nn.Linear
        # Creation order is kept (fc1 first) so parameter initialisation draws
        # from the RNG in the same sequence.
        self.fc1 = make_linear(3, 3, bias=False)
        self.fc2 = make_linear(3, 1, bias=False)

    def forward(self, x):
        """Apply ``fc1`` then ``fc2`` to ``x`` and return the result."""
        hidden = self.fc1(x)
        return self.fc2(hidden)


def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
    """Run one forward/backward pass and compare gradients across all ranks.

    ``fc2.weight`` is passed to ``ddp_cls.set_params_to_ignore`` before the model is
    wrapped by ``init_ddp_func``. After ``model.backward`` the per-rank tensors are
    collected with ``dist.all_gather``: those for ``fc1.weight`` must match rank 0 on
    every rank, those for ``fc2.weight`` must differ from rank 0 on every other rank.

    Args:
        ddp_cls: Wrapper class being exercised; used for ``set_params_to_ignore`` and
            to pick which tensor is gathered for ``fc1``.
        init_ddp_func: Wraps the freshly built ``Net`` and returns the object whose
            call and ``backward`` methods are used.

    Raises:
        AssertionError: If fewer than two ranks take part or a comparison fails.
    """
    with ColoInitContext(device=get_current_device()):
        model = Net().cuda()
    # Keep handles to the raw parameters before the model is wrapped.
    w1 = model.fc1.weight
    w2 = model.fc2.weight
    ddp_cls.set_params_to_ignore([w2])
    model = init_ddp_func(model)
    x = torch.rand(2, 3, device=get_current_device())
    logits = model(x)
    loss = torch.sum(logits)
    model.backward(loss)

    if ddp_cls is ZeroDDP:
        # NOTE(review): for ZeroDDP the parameter tensor itself is gathered instead of
        # ``.grad`` -- presumably the gradient lives in the parameter's storage after
        # backward; confirm against the ZeroDDP implementation.
        w1s_grad = w1
    else:
        w1s_grad = w1.grad

    world_size = dist.get_world_size()
    # Comparing ranks is meaningless with a single process; fail with a clear message
    # instead of an IndexError further down.
    assert world_size > 1, f'run_fwd_bwd needs at least 2 ranks, got {world_size}'

    w1_grads = [torch.empty_like(w1) for _ in range(world_size)]
    dist.all_gather(w1_grads, w1s_grad)
    for rank, grad in enumerate(w1_grads[1:], start=1):
        assert torch.equal(w1_grads[0], grad), f'fc1.weight gradient on rank {rank} differs from rank 0'

    w2_grads = [torch.empty_like(w2) for _ in range(world_size)]
    dist.all_gather(w2_grads, w2.grad)
    for rank, grad in enumerate(w2_grads[1:], start=1):
        assert not torch.equal(w2_grads[0], grad), \
            f'ignored fc2.weight gradient on rank {rank} unexpectedly equals rank 0'


def run_dist(rank, world_size, port):
    """Per-process body: launch ColossalAI, seed by rank, then check both wrappers.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: World size forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch`` (host is ``localhost``).
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_kwargs)
    # Each rank seeds with its own rank number.
    set_seed(dist.get_rank())
    for wrapper_cls, wrap_fn in ((ColoDDP, init_ddp), (ZeroDDP, init_ddpv2)):
        run_fwd_bwd(wrapper_cls, wrap_fn)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_ddp_ignore_params(world_size):
    """Pytest entry point: hand ``run_dist`` and ``world_size`` to ``spawn``.

    Parametrized with a single ``world_size`` of 2.

    NOTE(review): ``spawn`` presumably starts ``world_size`` processes and supplies
    the ``rank``/``port`` arguments of ``run_dist`` -- confirm against
    ``colossalai.testing``.
    """
    spawn(run_dist, world_size)


# Allow running this file directly (outside pytest) with a world size of 2.
if __name__ == '__main__':
    test_ddp_ignore_params(2)