import argparse

import torch
from benchmark_utils import benchmark
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AlbertConfig,
    AlbertForSequenceClassification,
    BertConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

# ==============================
# Prepare Hyperparameters
# ==============================
# Number of epochs: passed to benchmark() as epoch_num and used to size the LR schedule.
NUM_EPOCHS = 3
# Batch size handed to the DataLoader built in main().
BATCH_SIZE = 32
# Base learning rate; main() multiplies it by the distributed world size.
LEARNING_RATE = 2.4e-5
# Weight decay for the parameter group that is not excluded from decay (see main()).
WEIGHT_DECAY = 0.01
# Fraction of the total scheduler steps used for linear warmup.
WARMUP_FRACTION = 0.1
# Length of each synthetic token sequence produced by RandintDataset.
SEQ_LEN = 512
# Exclusive upper bound of the random token ids; also the model config's vocab_size.
VOCAB_SIZE = 1000
# Number of classes: exclusive upper bound of the random labels and the model's num_labels.
NUM_LABELS = 10
# Number of synthetic samples in the dataset.
DATASET_LEN = 1000


class RandintDataset(Dataset):
    """Synthetic classification dataset made of random integer tensors.

    Every sample is generated once, at construction time, with
    ``torch.randint``; indexing afterwards only slices the stored tensors.

    Args:
        dataset_length: Number of samples to generate.
        sequence_length: Number of token ids per sample.
        vocab_size: Exclusive upper bound for the generated token ids.
        n_class: Exclusive upper bound for the generated labels.
    """

    def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int):
        self._dataset_length = dataset_length
        self._sequence_length = sequence_length
        self._vocab_size = vocab_size
        self._n_class = n_class

        # Token ids in [0, vocab_size): one row per sample.
        token_shape = (dataset_length, sequence_length)
        self._datas = torch.randint(0, vocab_size, token_shape, dtype=torch.long)

        # Labels in [0, n_class), stored as a column so each item is a length-1 tensor.
        label_shape = (dataset_length, 1)
        self._labels = torch.randint(0, n_class, label_shape, dtype=torch.long)

    def __len__(self):
        """Return the number of samples generated at construction."""
        return self._dataset_length

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` pair stored at position ``idx``."""
        tokens = self._datas[idx]
        label = self._labels[idx]
        return tokens, label


def main():
    """Benchmark BERT/ALBERT sequence-classification training under a ColossalAI plugin.

    Parses the command line, launches the distributed environment, builds a
    randomly initialised model together with a synthetic ``RandintDataset``,
    boosts model/optimizer/scheduler with the selected plugin, runs
    ``benchmark`` and prints its result on the master process.

    Command-line options:
        -t/--task: Accepted for CLI compatibility; not read by this function.
        -p/--plugin: One of ``torch_ddp``, ``torch_ddp_fp16``, ``gemini``,
            ``low_level_zero``.
        --model_type: ``bert`` or ``albert``. Anything else is rejected by
            argparse before any distributed setup happens.
    """
    # Maps each supported --model_type to its (config class, model class) pair.
    model_builders = {
        "bert": (BertConfig, BertForSequenceClassification),
        "albert": (AlbertConfig, AlbertForSequenceClassification),
    }

    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        # Validate here so an unsupported value fails fast with a clear message
        # instead of after the distributed launch.
        choices=list(model_builders),
        help="bert or albert",
    )

    args = parser.parse_args()

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # The learning rate is scaled linearly with the world size. The batch size
    # is NOT divided by the world size: the DataLoader below uses BATCH_SIZE as is.
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        # Covers both 'torch_ddp' and 'torch_ddp_fp16'.
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    train_dataset = RandintDataset(dataset_length=DATASET_LEN,
                                   sequence_length=SEQ_LEN,
                                   vocab_size=VOCAB_SIZE,
                                   n_class=NUM_LABELS)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # The model is built from a fresh config (random weights), not from a
    # pretrained checkpoint.
    config_cls, model_cls = model_builders[args.model_type]
    cfg = config_cls(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
    model = model_cls(cfg)

    # optimizer: parameters whose name contains one of these substrings get no
    # weight decay; all others use WEIGHT_DECAY.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler: linear warmup over the first WARMUP_FRACTION of all steps.
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    # criterion
    def criterion(inputs):
        # Returns the first element of whatever benchmark() passes in.
        # NOTE(review): presumably the model output whose first entry is the
        # loss - confirm against benchmark_utils.benchmark.
        return inputs[0]

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

    # ==============================
    # Benchmark model
    # ==============================
    results = benchmark(model,
                        booster,
                        optimizer,
                        lr_scheduler,
                        train_dataloader,
                        criterion=criterion,
                        epoch_num=NUM_EPOCHS)

    coordinator.print_on_master(results)


# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()