import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):

    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def half_close(a, b, loose=False):
    rtol = None
    atol = None
    if loose:
        rtol = 5e-2
        atol = 5e-4

    a = a.detach().half()
    b = b.detach().half()

    assert_close(a, b, rtol=rtol, atol=atol)


def exam_zero_1_2():
    """
    In this test, we want to test whether zero stage 1 and 2
    deliver the same numerical results despite different communication
    pattern

    we use these prefixes to differentiate the zero stage
    oss: partition optimizer states
    pg: partition gradients and optimizer states

    """
    local_rank = torch.distributed.get_rank()
    seed_all(2001)

    # create model
    zero1_model = MlpModel().cuda()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
    zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
                                            overlap_communication=True,
                                            initial_scale=128,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=128)
    # create data
    seed_all(2001 + local_rank)
    input_data = torch.randn(32, 128).cuda()

    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    assert torch.equal(zero1_output, zero2_output)

    # zero-dp backward
    zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
    zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)

    for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
        if z2p.grad is not None:
            # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
            assert torch.equal(z1p.grad, z2p.grad)

    zero1_optimizer._sync_grad()
    zero2_optimizer._sync_grad()

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_torch_ddp():
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters
    2. torch: use torch DDP and fp32 parameters

    We feed these two sets of models with the same input and check if the
    differences in model output and updated parameters are within tolerance.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models
    zero_model = MlpModel()
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
    torch_model = torch_model.cuda()

    # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
    #    half_close(p.data, z1p.data)

    # create optimizer
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           reduce_bucket_size=262144)

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    seed_all(1453 + local_rank)
    # create
    input_data = torch.rand(32, 128).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.half())

    # torch-ddp forward
    torch_output = torch_model(input_data)
    half_close(zero_output, torch_output, loose=True)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad
    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
        half_close(p.grad, z1p.grad, loose=True)

    # zero-dp step
    zero_optimizer._sync_grad()
    zero_optimizer.step()

    # torch ddp step
    torch_optimizer.step()

    # check updated param
    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
        # print(n, torch.max(torch.abs(p.data - z1p.data)))
        half_close(p.data, z1p.data, loose=True)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_torch_ddp()
    exam_zero_1_2()


@pytest.mark.dist
def test_zero_1_2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_1_2()