import argparse
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
# Total number of training epochs; also shown in the progress-bar title.
NUM_EPOCHS = 80
# Base learning rate. main() multiplies this global by the world size
# (linear scaling) before building the optimizer.
LEARNING_RATE = 1e-3


def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
    """Create the CIFAR-10 training and test dataloaders.

    The dataset root is read from the ``DATA`` environment variable and falls
    back to ``./data``. Both datasets are constructed with ``download=True``.

    Args:
        batch_size: Batch size forwarded to ``plugin.prepare_dataloader`` for
            both loaders.
        coordinator: Coordinator whose ``priority_execution()`` context wraps
            dataset construction (presumably so that a single process performs
            the download first -- confirm against the coordinator docs).
        plugin: Plugin whose ``prepare_dataloader`` builds the loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        shuffles and drops the last incomplete batch; the test loader does
        neither.
    """
    # Training-time augmentation: pad by 4, random horizontal flip, random 32x32 crop.
    augmentation_steps = [
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
    ]
    train_transform = transforms.Compose(augmentation_steps)
    # Evaluation images are only converted to tensors.
    test_transform = transforms.ToTensor()

    dataset_root = os.environ.get('DATA', './data')

    with coordinator.priority_execution():
        train_set = torchvision.datasets.CIFAR10(root=dataset_root,
                                                 train=True,
                                                 transform=train_transform,
                                                 download=True)
        test_set = torchvision.datasets.CIFAR10(root=dataset_root,
                                                train=False,
                                                transform=test_transform,
                                                download=True)

    train_loader = plugin.prepare_dataloader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = plugin.prepare_dataloader(test_set, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, test_loader


@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
    """Compute the classification accuracy of ``model`` over ``test_dataloader``.

    Each process counts correct predictions and samples for the batches it
    iterates over; the two counters are then combined across all processes
    with ``dist.all_reduce`` (default reduction: sum), so every process returns
    the same global accuracy.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: Model to evaluate; its output is arg-maxed over dimension 1.
        test_dataloader: Iterable yielding ``(images, labels)`` batches.
        coordinator: Used only to restrict the summary print to the master process.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If no samples were seen on any process.
    """
    model.eval()
    # Resolve the device once so counters and batches are guaranteed to share it.
    device = get_current_device()
    correct = torch.zeros(1, dtype=torch.int64, device=device)
    total = torch.zeros(1, dtype=torch.int64, device=device)
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Gradients are already disabled, so no ``.data`` detour is needed.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # Accumulate on-device: calling ``.item()`` here would force a
        # device-to-host synchronisation on every batch.
        correct += (predicted == labels).sum()
    dist.all_reduce(correct)
    dist.all_reduce(total)
    accuracy = correct.item() / total.item()
    if coordinator.is_master():
        print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %')
    return accuracy


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader,
                booster: Booster, coordinator: DistCoordinator):
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        epoch: Zero-based epoch index; displayed one-based in the progress bar.
        model: Model to train; put into training mode before iterating.
        optimizer: Optimizer stepped and zeroed after every batch.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        train_dataloader: Iterable yielding ``(images, labels)`` batches.
        booster: Its ``backward(loss, optimizer)`` performs the backward pass.
        coordinator: The progress bar is disabled on non-master processes.
    """
    model.train()
    bar_title = f'Epoch [{epoch + 1}/{NUM_EPOCHS}]'
    hide_bar = not coordinator.is_master()
    with tqdm(train_dataloader, desc=bar_title, disable=hide_bar) as pbar:
        for batch_images, batch_labels in pbar:
            batch_images = batch_images.cuda()
            batch_labels = batch_labels.cuda()

            # Forward pass and loss.
            loss = criterion(model(batch_images), batch_labels)

            # Backward pass goes through the booster, then update and reset gradients.
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()

            # Show the current batch loss next to the progress bar.
            pbar.set_postfix({'loss': loss.item()})


def main():
    """Train a ResNet-18 on CIFAR-10 with a ColossalAI booster plugin.

    Steps: parse CLI arguments, create the checkpoint directory, launch the
    distributed environment, scale the learning rate by the world size, build
    the plugin/booster, dataloaders, model, optimizer and LR scheduler,
    optionally resume from a checkpoint, train for ``NUM_EPOCHS`` epochs with
    periodic checkpointing, and finally evaluate on the test set.

    Checkpoints are written after epoch ``e`` (zero-based) as ``*_{e + 1}.pth``;
    ``--resume N`` loads ``*_{N}.pth`` and continues with zero-based epoch ``N``.

    Side effects:
        Multiplies the module-level ``LEARNING_RATE`` by the world size.

    Raises:
        AssertionError: If ``--target_acc`` is given and the final accuracy is
            below it.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    # FIXME(ver217): gemini is not supported resnet now
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
    parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
    parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
    parser.add_argument('--target_acc',
                        type=float,
                        default=None,
                        help="target accuracy. Raise exception if not reached")
    args = parser.parse_args()

    # ==============================
    # Prepare Checkpoint Directory
    # ==============================
    # Only needed when checkpointing is enabled (interval > 0).
    if args.interval > 0:
        Path(args.checkpoint).mkdir(parents=True, exist_ok=True)

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={})
    coordinator = DistCoordinator()

    # update the learning rate with linear scaling
    # old_gpu_num / old_lr = new_gpu_num / new_lr
    global LEARNING_RATE
    LEARNING_RATE *= coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        # Currently unreachable: 'gemini' is not in the --plugin choices (see FIXME above).
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin)

    # ====================================
    # Prepare model, optimizer, criterion
    # ====================================
    # ResNet-18 with a 10-class head for CIFAR-10
    model = torchvision.models.resnet18(num_classes=10)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE)

    # lr scheduler
    lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, criterion, _, lr_scheduler = booster.boost(model,
                                                                 optimizer,
                                                                 criterion=criterion,
                                                                 lr_scheduler=lr_scheduler)

    # ==============================
    # Resume from checkpoint
    # ==============================
    if args.resume >= 0:
        booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
        booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
        booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')

    # ==============================
    # Train model
    # ==============================
    start_epoch = args.resume if args.resume >= 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator)
        lr_scheduler.step()

        # save checkpoint
        if args.interval > 0 and (epoch + 1) % args.interval == 0:
            booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
            booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
            booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')

    accuracy = evaluate(model, test_dataloader, coordinator)
    if args.target_acc is not None:
        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``; the exception type is kept for existing callers.
        if not accuracy >= args.target_acc:
            raise AssertionError(f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}')


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()