import argparse
import os
from pathlib import Path

import torch
from titans.utils import barrier_context
from torch.fx import GraphModule
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm

import colossalai
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader

# Dataset root directory: read from the DATA environment variable, falling back
# to '../data' when it is unset; always stored as an absolute path.
DATA_ROOT = Path(os.getenv('DATA', '../data')).absolute()


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call relies on.

    Returns:
        argparse.Namespace: namespace with a boolean ``synthetic`` attribute,
        ``True`` when ``-s`` / ``--synthetic`` was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
    return parser.parse_args(argv)


def synthesize_data(batch_size=None):
    """Create one random (image, label) batch shaped like CIFAR10 samples.

    Args:
        batch_size: Number of samples to generate. When ``None`` (the
            default) the value is read from ``gpc.config.BATCH_SIZE``, so a
            call without arguments behaves as before.

    Returns:
        tuple: ``(img, label)`` where ``img`` is a float tensor of shape
        ``(batch_size, 3, 32, 32)`` drawn from ``torch.rand`` and ``label``
        is an integer tensor of shape ``(batch_size,)`` with values in
        ``[0, 10)``.
    """
    if batch_size is None:
        # Only touch the global config when the caller did not pick a size.
        batch_size = gpc.config.BATCH_SIZE
    img = torch.rand(batch_size, 3, 32, 32)
    label = torch.randint(low=0, high=10, size=(batch_size,))
    return img, label


def main():
    """Train and evaluate an auto-parallelized ResNet-50 on CIFAR10 or synthetic data.

    The function runs these stages in order:

    1. Parse the CLI options and initialize ColossalAI through
       ``colossalai.launch_from_torch`` with ``./config.py``.
    2. Unless ``--synthetic`` is set, build the CIFAR10 train/test dataloaders.
    3. Build a device mesh from mesh ids ``0..3`` arranged as ``(2, 2)``,
       trace the model with a meta-tensor input and run the solver to pick
       one strategy per graph node.
    4. Apply the runtime preparation/apply passes to the traced graph module.
    5. For ``gpc.config.NUM_EPOCHS`` epochs, run a training pass followed by
       an evaluation pass and log the metrics on rank 0.

    ``BATCH_SIZE`` and ``NUM_EPOCHS`` are read from ``gpc.config``.
    """
    args = parse_args()
    colossalai.launch_from_torch(config='./config.py')

    logger = get_dist_logger()

    if not args.synthetic:
        # NOTE(review): barrier_context presumably coordinates the dataset download
        # across ranks — confirm against titans.utils.
        with barrier_context():
            # build dataloaders
            train_dataset = CIFAR10(root=DATA_ROOT,
                                    download=True,
                                    transform=transforms.Compose([
                                        transforms.RandomCrop(size=32, padding=4),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                                             std=[0.2023, 0.1994, 0.2010]),
                                    ]))

        # the test split is built without download=True and without augmentation
        test_dataset = CIFAR10(root=DATA_ROOT,
                               train=False,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                               ]))

        train_dataloader = get_dataloader(
            dataset=train_dataset,
            add_sampler=True,
            shuffle=True,
            batch_size=gpc.config.BATCH_SIZE,
            pin_memory=True,
        )

        test_dataloader = get_dataloader(
            dataset=test_dataset,
            add_sampler=True,
            batch_size=gpc.config.BATCH_SIZE,
            pin_memory=True,
        )
    else:
        # synthetic batches are generated on the fly inside the loops below
        train_dataloader, test_dataloader = None, None

    # initialize device mesh
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # trace the model with meta data
    tracer = ColoTracer()
    model = resnet50(num_classes=10).cuda()
    # the traced sample uses the global batch (per-rank batch size times world size)
    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # prepare info for solver
    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)

    # solve the solution
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    if gpc.get_global_rank() == 0:
        # print the chosen strategy name for every graph node
        for index, node in enumerate(graph.nodes):
            print(node.name, node.strategies_vector[solution[index]].name)

    # process the graph for distributed training ability
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
    gm = runtime_apply_pass(gm)
    gm.recompile()

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    # NOTE(review): the optimizer is built from `model` while the forward pass runs
    # through `gm`; presumably both expose the same parameter objects — confirm.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    # lr_scheduler
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

    for epoch in range(gpc.config.NUM_EPOCHS):
        gm.train()

        if args.synthetic:
            # if we use synthetic data
            # we assume it only has 30 steps per epoch
            num_steps = range(30)

        else:
            # we use the actual number of steps for training
            num_steps = range(len(train_dataloader))
            data_iter = iter(train_dataloader)
        progress = tqdm(num_steps)

        for _ in progress:
            if args.synthetic:
                # generate fake data
                img, label = synthesize_data()
            else:
                # get the real data
                img, label = next(data_iter)

            img = img.cuda()
            label = label.cuda()
            optimizer.zero_grad()
            output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
            train_loss = criterion(output, label)
            # Call backward() without an explicit gradient: passing the loss itself
            # as the seed would scale every gradient by the current loss value.
            train_loss.backward()
            optimizer.step()
        # the scheduler advances once per epoch
        lr_scheduler.step()

        # run evaluation
        gm.eval()
        correct = 0
        total = 0

        if args.synthetic:
            # if we use synthetic data
            # we assume it only has 30 steps for evaluation
            num_steps = range(30)

        else:
            # we use the actual number of steps for evaluation
            num_steps = range(len(test_dataloader))
            data_iter = iter(test_dataloader)
        progress = tqdm(num_steps)

        for _ in progress:
            if args.synthetic:
                # generate fake data
                img, label = synthesize_data()
            else:
                # get the real data
                img, label = next(data_iter)

            img = img.cuda()
            label = label.cuda()

            with torch.no_grad():
                output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
                test_loss = criterion(output, label)
            pred = torch.argmax(output, dim=-1)
            correct += torch.sum(pred == label)
            total += img.size(0)

        # the logged train/test losses are those of the last batch of each pass
        logger.info(
            f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
            ranks=[0])


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()