import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


@torch.no_grad()
def get_global_loss(loss, booster):
    global_loss = loss.clone().detach()
    dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
    global_loss.div_(booster.plugin.dp_size)
    return global_loss


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help=" The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # lr scheduler
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communicate overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    if args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=args.pp_size,
            ep_size=args.ep_size,
            microbatch_size=args.microbatch_size,
            custom_policy=MixtralForCausalLMPolicy(),
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
            precision=args.precision,
            zero_stage=args.zero_stage,
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        )

    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master(f"Finish init model")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master(f"Finish init booster")

    # Load ckpt
    if args.load_checkpoint is not None:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master(f"Finish load optimizer")

    # Start finetuning
    coordinator.print_on_master(f"Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        if coordinator._local_rank == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Apply load balance
                # if (
                #     args.load_balance
                #     and args.load_balance_interval > 0
                #     and (step + 1) % args.load_balance_interval == 0
                # ):
                #     coordinator.print_on_master(f"Apply load balance")
                #     apply_load_balance(model, optimizer)
                # save ckeckpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # save checkpoint at the end of each epochs
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master(f"Finish training")


if __name__ == "__main__":
    main()