#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import socket
import time
import traceback
from functools import partial

import torch
import torch.distributed as dist

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import (
    get_train_data_loader,
    get_validation_data_loader,
    initialize_llm_profile,
    initialize_model,
    initialize_optimizer,
    load_new_batch,
    record_current_batch_training_metrics,
)
from internlm.utils.common import (
    BatchSkipper,
    get_megatron_flops,
    launch_time,
    parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import bench_gpu, bench_net
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
from internlm.utils.writer import Writer

# global llm logger
logger = get_logger(__file__)


def initialize_llm_logger(start_time: str):
    """
    Initialize the customized uniscale logger.

    Args:
        start_time (str): The launch time of current training job.

    Returns:
        The uniscale logger instance, or None if it is not available.
    """

    uniscale_logger = initialize_uniscale_logger(
        job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
    )
    if uniscale_logger is not None:
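        # replace the module-level logger so subsequent log calls go through the uniscale logger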
        global logger
        logger = uniscale_logger

    return uniscale_logger


def main(args):
    # init setting
    skip_batches = gpc.config.data.skip_batches
    total_steps = gpc.config.data.total_steps
    valid_every = gpc.config.data.valid_every
    label_smoothing = gpc.config.loss.label_smoothing
    lr = gpc.config.adam.lr

    get_tflops_func = partial(
        get_megatron_flops,
        checkpoint=gpc.config.model.checkpoint,
        seq_len=gpc.config.SEQ_LEN,
        hidden_size=gpc.config.model.hidden_size,
        num_layers=gpc.config.model.num_layers,
        vocab_size=gpc.config.model.vocab_size,
        global_batch_size=gpc.config.data.micro_bsz
        * gpc.config.data.micro_num
        * gpc.get_world_size(ParallelMode.DATA),
        global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
        mlp_ratio=gpc.config.MLP_RATIO,
    )

    # get the launch time and broadcast it so that all ranks share the same timestamp
    current_time = launch_time()
    objs = [current_time]
    dist.broadcast_object_list(objs, src=0)
    current_time = objs[0]

    # initialize the customized llm logger
    uniscale_logger = initialize_llm_logger(start_time=current_time)

    # initialize and resume train state
    train_state = TrainState(gpc.config)

    # initialize model
    model = initialize_model()

    with open(args.config, "r") as f:
        config_lines = f.readlines()
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.alert_address,
    )

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)

    # initialize the train and validation data loaders
    train_dl, dataset_types = get_train_data_loader(num_worker=4)
    val_dls = get_validation_data_loader()
    train_state.init_batch_sampler(train_dl)

    # Loading model weights must be done before zero is initialized.
    ckpt_manager.try_load_model(current_time)

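    # initialize the optimizer (where zero is set up) and its beta2 / learning-rate schedulers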
    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    # Loading other persistent training states.
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

    # initialize the customized llm writer
    writer = Writer(
        job_name=gpc.config.JOB_NAME,
        launch_time=current_time,
        file_name=get_parallel_log_file_name(),
        tensorboard_folder=gpc.config.tensorboard_folder,
        resume_tb_folder=train_state.resume_tb_folder,  # resume from ckpt.
        step_count=train_state.step_count,  # resume from ckpt.
        config=config_lines,
        logger=logger,
        enable_tb=gpc.config.enable_tb,
    )

    # initialize metric for calculating accuracy and perplexity
    metric = AccPerplex(
        device=torch.cuda.current_device(),
        tp_pg=gpc.get_group(ParallelMode.TENSOR),
        dp_pg=gpc.get_group(ParallelMode.DATA),
        dataset_types=dataset_types,
    )

    # initialize trainer
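    # the metric hook is skipped when interleaved pipeline scheduling with overlap is enabled (num_chunks > 1)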
    scheduler_hooks = [
        SchedulerMetricHook(
            metric=metric,
            skip=(
                gpc.is_using_pp()
                and hasattr(gpc.config.model, "num_chunks")
                and gpc.config.model.num_chunks > 1
                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
            ),
        ),
    ]

    trainer, train_dl, _, _ = internlm.initialize_trainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dl,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        scheduler_hooks=scheduler_hooks,
    )

    # initialize simple memory profiler
    if args.profiling:
        memory_profiler = SimpleMemoryProfiler(
            model,
            optimizer.optim,
            log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
        )
    else:
        memory_profiler = None

    # initialize the batch skipper
    batch_skipper = BatchSkipper(skip_batches)

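    # put the trainer (and the underlying model) into training mode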
    trainer.train()

    # turn the train data loader into a train data iterator
    train_iter = iter(train_dl)

    with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
        # start iterating the train data and begin training
        for batch_count in range(train_state.batch_count, total_steps):
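            # periodically empty the CUDA cache and run lightweight GPU / network benchmarks as a health check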
            if batch_count % 50 == 0:
                torch.cuda.empty_cache()
                bench_gpu()
                bench_net()

            start_time = time.time()
            timer("one-batch").start()

            # load batch data
            batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)

            # record the consumed samples in training
            train_state.batch_count = batch_count
            train_state.num_consumed_samples_in_epoch += len(batch[1])
            if batch_skipper(batch_count):  # skip this batch
                if gpc.is_rank_for_log():
                    logger.info(f"Skip batch count:`{batch_count}`...")
                timer("one-batch").stop()
                continue

            # zero the grads of parameters
            trainer.zero_grad()
            # hand dataset type ids (if present) to the metric so per-dataset statistics can be tracked
            if batch[0].get("type_ids", None) is not None:
                metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))

            # do forward and backward
            timer("fwd-bwd").start()

            _, _, loss = trainer.execute_schedule(
                batch, forward_only=False, return_loss=True, return_output_label=False
            )
            timer("fwd-bwd").stop()

            # update parameters; trainer.step() returns (success_update, grad_norm_groups)
            trainer_result = trainer.step()
            assert trainer_result is not None

            success_update, grad_norm_groups = trainer_result
            if success_update:  # update parameters successfully
                train_state.step_count += 1
            else:
                train_state.inf_nan_skip_batches += 1  # record the number of batches whose parameter update was skipped
                if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                    logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                    send_alert_message(
                        address=gpc.config.alert_address,
                        message=f"Warning: skip parameter update at step {batch_count}.",
                    )

            # calculate and record the training metrics, e.g. loss, accuracy, etc.
            record_current_batch_training_metrics(
                get_tflops_func=get_tflops_func,
                logger=logger,
                writer=writer,
                success_update=success_update,
                batch_count=batch_count,
                batch=batch,
                train_state=train_state,
                optimizer=optimizer,
                beta2_scheduler=beta2_scheduler,
                trainer=trainer,
                start_time=start_time,
                loss=loss,
                grad_norm=grad_norm_groups,
                metric=metric,
                update_panel=uniscale_logger is not None,
            )

            timer("one-batch").stop()

            # evaluate on validation data loaders
            if valid_every > 0 and train_state.step_count % valid_every == 0:
                evaluate_on_val_dls(
                    trainer=trainer,
                    val_dls=val_dls,
                    writer=writer,
                    logger=logger,
                    step_count=train_state.step_count,
                    update_panel=uniscale_logger is not None,
                )

            # checkpoint the training state at the interval set by the "checkpoint_every" argument;
            # also save the batch sampler that tracks the true number of consumed samples.
            now_break = ckpt_manager.try_save_checkpoint(train_state)
            if now_break:
                break

            if memory_profiler is not None:
                memory_profiler.step()

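            # advance the profiler returned by initialize_llm_profile every other step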
            if batch_count % 2 == 0:
                prof.step()

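    # wait for any in-flight asynchronous checkpoint uploads to finish before exiting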
    ckpt_manager.wait_async_upload_finish()


if __name__ == "__main__":
    args = parse_args()
    hostname = socket.gethostname()

    # initialize distributed environment
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
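    # the parsed config is attached to gpc during distributed environment initialization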
    assert hasattr(gpc, "config") and gpc.config is not None

    # initialize monitor manager context
    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
        try:
            main(args)
        except Exception:
            logger.error(
                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
            )
            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())