#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import argparse
import os
import pprint
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.gradient_accumulation import accumulate_gradient
from colossalai.engine.schedule import (
    InterleavedPipelineSchedule,
    NonPipelineSchedule,
    PipelineSchedule,
    get_tensor_shape,
)
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook


def get_default_parser():
    """Reads user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

    Returns:
       Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
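
    Example:
        A minimal sketch of extending the parser with a user-defined flag (``--epochs`` is purely
        illustrative and not an argument provided by this function)::

            parser = get_default_parser()
            parser.add_argument('--epochs', type=int, default=10)
            args = parser.parse_args()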
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host', type=str, help='the master address for distributed training')
    parser.add_argument('--port', type=int, help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--rank', type=int, help='rank for the default process group')
    parser.add_argument('--local_rank', type=int, help='local rank on the node')
    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
    return parser


def launch(config: Union[str, Path, Config, Dict],
           rank: int,
           world_size: int,
           host: str,
           port: int,
           backend: str = 'nccl',
           local_rank: int = None,
           seed: int = 1024,
           verbose: bool = True):
    """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
    arguments are not given. Then initialize and set distributed environment by calling global_context's functions.

    Args:
        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
        rank (int): Rank for the default process group
        world_size (int): World size of the default process group
        host (str): The master address for distributed training
        port (int): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        local_rank (int, optional):
            Rank for the process on the node and is used to set the default CUDA device,
            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        AssertionError: Raised when the type of ``config`` is not one of ``Config``, ``str``, ``Path`` or ``dict``
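
    Example:
        A minimal sketch of launching a single process manually (``colossalai.launch`` refers to this function
        as re-exported at the package level; the config, host, port and sizes below are placeholders)::

            colossalai.launch(config=dict(),
                              rank=0,
                              world_size=1,
                              host='127.0.0.1',
                              port=29500)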
    """
    gpc.verbose = verbose

    # set config
    assert isinstance(config, (Config, str, Path, dict)), \
        f'expected argument config to be Config, str, Path or dict, but got {type(config)}'
    if not isinstance(config, Config) and isinstance(config, dict):
        config = Config(config)
    if isinstance(config, (str, Path)):
        config = Config.from_file(config)
    gpc.load_config(config)

    # init default process group
    gpc.init_global_dist(rank, world_size, backend, host, port)

    # init process groups for different parallel modes from config
    gpc.init_parallel_groups()

    # set cuda device
    if torch.cuda.is_available():
        # if local rank is not given, calculate automatically
        gpc.set_device(local_rank)

    # set the number of processes running on the same node
    gpc.detect_num_processes_on_current_node()

    gpc.set_seed(seed)

    if verbose:
        logger = get_dist_logger()
        logger.info(
            f'Distributed environment is initialized, '
            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
            f'tensor parallel size: {gpc.tensor_parallel_size}',
            ranks=[0])


def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
    set by SLURM

    Args:
        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
        host (str): The master address for distributed training
        port (int): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
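
    Example:
        A minimal sketch, assuming the script is started with ``srun`` so that ``SLURM_PROCID`` and
        ``SLURM_NPROCS`` are set (the config path, host and port below are placeholders)::

            colossalai.launch_from_slurm(config='./config.py',
                                         host='node001',
                                         port=29500)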
    """
    try:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NPROCS'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
        )

    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
    set by OpenMPI

    Args:
        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
        host (str): The master address for distributed training
        port (int): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
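
    Example:
        A minimal sketch, assuming the script is started with ``mpirun`` so that the ``OMPI_COMM_WORLD_*``
        variables are set (the config path, host and port below are placeholders)::

            colossalai.launch_from_openmpi(config='./config.py',
                                           host='node001',
                                           port=29500)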
    """
    try:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
    from the environment variables set by PyTorch

    Args:
        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
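
    Example:
        A minimal sketch, assuming the script is started with ``torchrun`` so that ``RANK``, ``LOCAL_RANK``,
        ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` are set (the config path is a placeholder)::

            colossalai.launch_from_torch(config='./config.py')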
    """
    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        host = os.environ['MASTER_ADDR']
        port = int(os.environ['MASTER_PORT'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def initialize(model: nn.Module,
               optimizer: Optimizer,
               criterion: Optional[_Loss] = None,
               train_dataloader: Optional[Iterable] = None,
               test_dataloader: Optional[Iterable] = None,
               lr_scheduler: Optional[_LRScheduler] = None,
               ophooks: Optional[List[BaseOpHook]] = None,
               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or ``Callable``): Your model instance or a function to build the model.
        optimizer (:class:`torch.optim.optimizer.Optimizer` or ``Callable``):
            Your optimizer instance, or a callable that builds the optimizer from the model parameters.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance.
        ophooks (List[:class:`BaseOpHook`], optional): A list of operator hooks to be registered to the engine.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Returns:
        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``engine`` is guaranteed to be non-``None``; the other elements may be ``None`` if they
            were not provided.
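
    Example:
        A minimal sketch, assuming the distributed environment has already been set up (e.g. via
        ``colossalai.launch_from_torch``) and that ``model``, ``optimizer``, ``criterion`` and
        ``train_dataloader`` have been built by the user::

            engine, train_dataloader, _, _ = colossalai.initialize(model=model,
                                                                    optimizer=optimizer,
                                                                    criterion=criterion,
                                                                    train_dataloader=train_dataloader)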
    """
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n",
            ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', False)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # zero
    use_zero = hasattr(gpc.config, 'zero')
    if use_zero:
        zero_cfg = gpc.config.get('zero', None)
        optimizer_config = zero_cfg.get('optimizer_config', None)
        model_config = zero_cfg.get('model_config', None)
        model, optimizer = convert_to_zero_v2(model,
                                              optimizer,
                                              model_config=model_config,
                                              optimizer_config=optimizer_config)

        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
    else:
        if isinstance(model, nn.Module):
            # move the model to the current device; parameter synchronization across dp ranks happens below
            model.to(get_current_device())
        elif isinstance(model, Callable):
            model = model().to(get_current_device())

        # the optimizer may be an optimizer class instead of an instance
        if isinstance(optimizer, Callable):
            optimizer = optimizer(model.parameters())
            logger.warning("Initializing a non-ZeRO model with an optimizer class")

    if not use_zero:
        if is_using_sequence():
            sync_model_param(model, ParallelMode.SEQUENCE_DP)
        elif MOE_CONTEXT.is_initialized:
            sync_moe_model_param(model)
        elif is_using_ddp():
            sync_model_param(model, ParallelMode.DATA)
    else:
        logger.warning(
            "The parameters of models is not automatically synchronized.\n"
            "Please make sure that all parameters are the same in data parallel group.",
            ranks=[0])

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)

    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)

    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        if is_using_pp():
            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline parallelism only supports NaiveAMP currently'
        if amp_mode == AMP_TYPE.NAIVE:
            cfg_['clip_grad_norm'] = clip_grad_norm
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    # get torch ddp config
    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if the gradient handler is not specified in the configuration file,
        # check in the following order:
        # 1. if the optimizer is ZeRO, use the ZeRO gradient handler
        # 2. if MoE is used together with data parallelism, use the MoE gradient handler
        # 3. if sequence parallelism is used, wrap the model with PyTorch DDP
        # 4. if data parallelism is used without pipeline (and AMP is not NaiveAMP), wrap the model with PyTorch DDP
        # 5. otherwise, if data parallelism is used, use the data parallel gradient handler
        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_sequence():
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                            ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.DATA),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, "
                    "DataParallelGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        # add pipeline parallel gradient handler, if pipeline shared module is detected
        for param in model.parameters():
            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
                if gradient_handler_cfg is None:
                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
                else:
                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
                if verbose:
                    logger.info(
                        "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
                        "added even though not specified in the configuration",
                        ranks=[0])
                break
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
            )

    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
    # to avoid duplicated buffer synchronization
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False

    # initialize schedule for engine
    if is_using_pp():
        tensor_shape = get_tensor_shape()
        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
            scatter_gather = True
        else:
            scatter_gather = False
        if use_interleaved:
            if isinstance(model, nn.Sequential):
                model = nn.ModuleList([model])
            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                   gpc.config.model.num_chunks,
                                                   tensor_shape=tensor_shape,
                                                   scatter_gather_tensors=scatter_gather)
        else:
            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                        tensor_shape=tensor_shape,
                                        scatter_gather_tensors=scatter_gather)
    else:
        schedule = NonPipelineSchedule()

    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is ColossalaiOptimizer
    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
        optimizer = ColossalaiOptimizer(optim=optimizer)

    # gradient accumulation
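    # when `gradient_accumulation` is set in the config, the optimizer, train dataloader, gradient handlers
    # and lr scheduler are wrapped so that a real parameter update (and handler / scheduler step) only
    # happens once every `gradient_accumulation` iterations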
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
            model=model,
            optimizer=optimizer,
            dataloader=train_dataloader,
            accumulate_size=grad_accum_size,
            gradient_handlers=gradient_handlers,
            lr_scheduler=lr_scheduler)
    engine = Engine(model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    gradient_handlers=gradient_handlers,
                    clip_grad_norm=clip_grad_norm,
                    ophook_list=ophooks,
                    schedule=schedule)

    return engine, train_dataloader, test_dataloader, lr_scheduler