import os
import shutil
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR

import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
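    """Shard the linear weight along its last dim across the TP group (row-parallel 1D layout)."""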
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_linear(weight: ColoTensor, pg: ProcessGroup):
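    """Shard the linear weight along dim 0 across the TP group (column-parallel 1D layout)."""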
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_embedding(weight: ColoTensor, pg: ProcessGroup):
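    """Shard the embedding weight along dim 0 (the vocabulary dimension) across the TP group."""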
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_embedding(weight: ColoTensor, pg: ProcessGroup):
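    """Shard the embedding weight along its last dim (the embedding dimension) across the TP group."""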
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
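    """Apply a 1D TP layout: column-parallel embedding, proj1 and classifier; row-parallel proj2."""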
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)


def check_param_equal(model, torch_model):
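    """Assert that the two models' corresponding parameters are element-wise identical."""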
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)


def remove(path):
    """Remove <path>, which may be a relative or absolute path to a file, symlink, or directory."""
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("path {} is neither a file nor a directory.".format(path))


def compare_optims(optim1, optim2):
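    """Check that matching ColoTensor entries in the two optimizers' states are exactly equal."""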
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        for n, t1 in p1.items():
            if n not in p2:
                continue
            t2 = p2[n]
            if isinstance(t1, ColoTensor):
                assert isinstance(t2, ColoTensor)
                assert torch.allclose(t1, t2, rtol=0, atol=0)


def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
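    """Train both model copies briefly, save a checkpoint from model, load it into model_reload, and verify that parameters and optimizer states match."""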
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = dist.get_rank()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if use_mp_reload:
        if model_name == 'bert':
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.set_process_group(pg)
        elif model_name == 'simple_net':
            init_spec_func(model, pg)

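    # model_reload starts as an exact copy of model; it is later restored from the checkpoint saved from model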
    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()

    model_reload = model_reload.cuda()
    model_reload.eval()

    opt_class = torch.optim.Adam
    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        colo_optimizer_reload.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # broadcast rank 0's batch to all ranks in the TP group so every process sees the same data
        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())

        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        colo_optimizer.step()
        colo_optimizer_reload.step()

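        # a few training steps are enough to populate the optimizer state before checkpointing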
        if i > 2:
            break

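    # rank 0 creates the checkpoint directory; all ranks sync at the barrier before saving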
    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
    dist.barrier()

    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    if rank == 0:
        remove('./checkpoint')
    dist.barrier()


def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
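    """Per-process entry point: initialize the distributed environment and run the checkpoint test."""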
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)

    # the BERT data loader runs in DDP mode, so the input data is not replicated across ranks in the TP context
    for model_name in ['bert']:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)


if __name__ == '__main__':
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")