import argparse
import functools
import os

import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from utils import PerformanceEvaluator, get_model_numel

from colossalai.moe.manager import MOE_MANAGER


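# Synthetic dataset of random token ids; labels mirror input_ids for the causal LM loss.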
class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def fsdp_main(rank, world_size, args):
    # initialize the process group
    dist.init_process_group("nccl")

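    # parallel=None: skip ColossalAI's expert-parallel layouts so FSDP can
    # shard the MoE experts like ordinary parameters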
    MOE_MANAGER.setup(parallel=None)

    dp_size = dist.get_world_size()
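    # size the dataset so every rank gets exactly (warmup + active) batches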
    dataset = RandomDataset(
        max_length=args.seq_length,
        num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
    )
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
    train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
    torch.cuda.set_device(rank)

    config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model_name}")
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=False,
        enable_kernel=False,
        enable_comm_overlap=False,
    )
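    # build the model with fp16 defaults to halve initialization memory,
    # then restore the fp32 default for everything created afterwards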
    torch.set_default_dtype(torch.float16)
    model = OpenMoeForCausalLM(config)
    torch.set_default_dtype(torch.float32)
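    # wrap each OpenMoeDecoderLayer as its own FSDP unit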
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            OpenMoeDecoderLayer,
        },
    )
    model = FSDP(
        model,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
    model.train()

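    # throughput tracker; the first args.warmup steps are excluded from the measurement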
    model_numel = get_model_numel(model)
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        enable_grad_checkpoint=True,
        ignore_steps=args.warmup,
        dp_world_size=dist.get_world_size(),
    )

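    # benchmark loop: one forward/backward/optimizer step per batch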
    for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
        performance_evaluator.on_step_start(step)
        input_ids, attention_mask, labels = (
            data["input_ids"].cuda(),
            data["attention_mask"].cuda(),
            data["labels"].cuda(),
        )

        optimizer.zero_grad()
        output = model(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            chunk_head=False,
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        performance_evaluator.on_step_end(input_ids)

    performance_evaluator.on_fit_end()
    if dist.get_rank() == 0:
        print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b"],
        help="base or 8b",
    )
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seq_length", type=int, default=2048)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--active", type=int, default=20)
    args = parser.parse_args()

    torch.manual_seed(42)

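    # WORLD_SIZE and LOCAL_RANK are set by the distributed launcher (e.g. torchrun)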
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    fsdp_main(local_rank, world_size, args)