#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
import socket
from pathlib import Path
from typing import Callable, List, Union, Dict, Optional
import functools

import torch
from torch._six import inf
from torch.nn.parameter import Parameter

try:
    import colossal_C
except:
    pass

from contextlib import contextmanager

import torch.distributed as dist
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from .multi_tensor_apply import multi_tensor_applier

from colossalai.tensor import ColoParameter, ProcessGroup
from collections import defaultdict


def print_rank_0(msg: str, logger=None):
    """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.

    Args:
        msg (str): A string message to output.
        logger (:class:`colossalai.logging.DistributedLogger`, optional):
            The logger to record the message, defaults to None.
    """
    if gpc.get_global_rank() == 0:
        if logger is None:
            print(msg, flush=True)
        else:
            logger.info(msg)


def ensure_path_exists(filename: str):
    # ensure the path exists
    dirpath = os.path.dirname(filename)
    if not os.path.exists(dirpath):
        Path(dirpath).mkdir(parents=True, exist_ok=True)


def free_port():
    while True:
        try:
            sock = socket.socket()
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            port = random.randint(20000, 65000)
            sock.bind(('localhost', port))
            sock.close()
            return port
        except Exception:
            continue


def sync_model_param(model, parallel_mode):
    r"""Make sure data parameters are consistent during Data Parallel Mode.

    Args:
        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked.

    Note:
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
        for param in model.parameters():
            ranks = gpc.get_ranks_in_group(parallel_mode)
            dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))


def is_dp_rank_0():
    return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)


def is_tp_rank_0():
    return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)


def is_no_pp_or_last_stage():
    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)


def is_using_ddp():
    return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1


def is_using_pp():
    return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1


def is_using_sequence():
    return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1


@contextmanager
def conditional_context(context_manager, enable=True):
    if enable:
        with context_manager:
            yield
    else:
        yield


class model_branch_context(object):

    def __enter__(self):
        self.env_status = env.save()

    def __exit__(self, *exc_info):
        env.load(**self.env_status)


def is_model_parallel_parameter(p):
    return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)


def _calc_l2_norm(grads):
    norm = 0.0
    if len(grads) > 0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        norm, _ = multi_tensor_applier(
            colossal_C.multi_tensor_l2norm,
            dummy_overflow_buf,
            [grads],
            False    # no per-parameter norm
        )
    return norm


def _calc_lp(grads, norm_type):
    norm = 0.0
    for grad in grads:
        grad_norm = torch.norm(grad, norm_type)
        norm += grad_norm**norm_type
    return norm


def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    if torch.is_tensor(norm) and norm.device.type != 'cuda':
        norm = norm.to(torch.cuda.current_device())
    return norm


def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    if isinstance(norm, float):
        norm = torch.Tensor([norm])
    if move_to_cuda:
        norm = norm.to(torch.cuda.current_device())
    return norm


# ======== Gradient Clipping =========


def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
    if len(params) == 0:
        return 0.0
    grads = [p.grad for p in params]
    use_cuda_kernel = grads[0].device.type == 'cuda'
    if norm_type == inf:
        local_lp = max([g.abs().max() for g in grads])
    elif norm_type == 2.0 and use_cuda_kernel:
        local_lp = _calc_l2_norm(grads)**norm_type
    else:
        local_lp = _calc_lp(grads, norm_type)
    if isinstance(local_lp, torch.Tensor):
        return local_lp.item()
    return local_lp


def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
    if len(params) == 0:
        return 0.0
    buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
    for p in params:
        if p.is_replicate():
            buckets[None].append(p)
        else:
            buckets[p.get_process_group().tp_process_group()].append(p)
    total_lp = 0.0
    for group, bucket in buckets.items():
        local_lp = _compute_local_lp(bucket, norm_type)
        if group is not None:
            local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
            if norm_type == inf:
                dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
            else:
                dist.all_reduce(local_lp_tensor, group=group)
            local_lp = local_lp_tensor.item()
        if norm_type == inf:
            total_lp = max(total_lp, local_lp)
        else:
            total_lp += local_lp
    return total_lp


def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
    if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
        if norm_type == inf:
            dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
        else:
            dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
        total_lp = total_lp_tensor.item()
    return total_lp


def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grad_dtype = None
    cpu_grad_params: List[ColoParameter] = []
    cuda_grad_params: List[ColoParameter] = []
    for p in parameters:
        if p.grad is None:
            continue
        assert isinstance(p, ColoParameter)
        if grad_dtype is None:
            grad_dtype = p.grad.dtype
        assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
        if p.grad.device.type == 'cuda':
            cuda_grad_params.append(p)
        else:
            cpu_grad_params.append(p)
    norm_type = float(norm_type)
    cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
    cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
    if norm_type == inf:
        total_lp = max(cpu_lp, cuda_lp)
    else:
        total_lp = cpu_lp + cuda_lp
    return _compute_pp_grad_lp(total_lp, norm_type)


def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
    norm_type = float(norm_type)
    total_norm = _compute_grad_lp(parameters, norm_type)
    if norm_type != inf:
        total_norm = total_norm**(1 / norm_type)
    return total_norm


def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        cuda_grads: List[torch.Tensor] = []
        cpu_grads: List[torch.Tensor] = []
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        for p in parameters:
            if p.grad is None:
                continue
            if p.grad.device.type == 'cuda':
                cuda_grads.append(p.grad.detach())
            else:
                cpu_grads.append(p.grad.detach())
        if len(cuda_grads) > 0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
        for g in cpu_grads:
            g.mul_(clip_coef)


def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
    total_norm = compute_grad_norm(parameters, norm_type)
    _clip_grad_norm(parameters, max_norm, total_norm)
    return total_norm


def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients are in fp32.

    This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
    added functionality to handle model parallel parameters.

    Note:
        the gradients are modified in place.

    Args:
        parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`):
            An iterable of Tensors or a single Tensor that will have gradients normalized.
        max_norm (Union[float, int]): Max norm of the gradients.
        norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm.

    Returns:
        float: Total norm of the parameters.
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    params: List[Parameter] = []
    has_zero_shared_param: bool = False
    for param in parameters:
        if param.grad is not None:
            # Make sure the grads are in fp32
            assert param.grad.dtype == torch.float, \
                f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
            if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
                has_zero_shared_param = True
            params.append(param)

    if len(params) == 0:
        enable_cuda_kernels = False
    else:
        enable_cuda_kernels = params[0].grad.device.type == 'cuda'
    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Parameters can be on CPU or CUDA
    # If parameters are on CPU, disable CUDA kernerls

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in params)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.MAX,
                            group=gpc.get_group(ParallelMode.MODEL),
                            async_op=False)
        if has_zero_shared_param:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.MAX,
                            group=gpc.get_group(ParallelMode.DATA),
                            async_op=False)
        total_norm = total_norm_cuda[0].item()
    else:
        tensor_parallel_grads = []
        no_tensor_parallel_grads = []
        zero_sharded_grads = []
        for p in params:
            if is_model_parallel_parameter(p):
                reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
                tensor_parallel_grads.append(p.grad.data / reductor)
            elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
                zero_sharded_grads.append(p.grad.data)
            else:
                no_tensor_parallel_grads.append(p.grad.data)

        if norm_type == 2.0 and enable_cuda_kernels:
            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
        else:
            tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
            no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
            zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
        # If norm is type of float, then we convert them into torch.Tensor.
        tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
        no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
        zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
        # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
        if not enable_cuda_kernels:
            tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
            no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
            zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)

        # Sum across all model-parallel GPUs.
        if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
            dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
        # Sum across all zero sharded GPUs
        if len(zero_sharded_grads) > 0:
            dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
            no_tensor_parallel_norm += zero_sharded_norm
        total_norm = tensor_parallel_norm + no_tensor_parallel_norm
        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
        total_norm = total_norm**(1.0 / norm_type)
        if torch.is_tensor(total_norm):
            total_norm = total_norm.item()

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        if enable_cuda_kernels:
            grads = [p.grad.detach() for p in params]
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
        else:
            for p in params:
                p.grad.detach().mul_(clip_coeff)
    return total_norm


def count_zeros_fp32(parameters):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    total_num_zeros = 0.0
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()

    # Sum across all model-parallel GPUs.
    ops = []
    ops.append(
        dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
    if gpc.is_initialized(ParallelMode.PIPELINE):
        ops.append(
            dist.all_reduce(total_num_zeros,
                            op=dist.ReduceOp.SUM,
                            group=gpc.get_group(ParallelMode.PIPELINE),
                            async_op=True))

    for req in ops:
        req.wait()
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros


def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
    for attr in TENSOR_PARALLEL_ATTRIBUTES:
        if hasattr(src_tensor, attr):
            val = getattr(src_tensor, attr)
            setattr(dst_tensor, attr, val)


def param_is_not_tensor_parallel_duplicate(param):
    return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
        ParallelMode.TENSOR) == 0)


@contextmanager
def switch_virtual_pipeline_parallel_rank(rank):
    prev_rank = gpc.virtual_pipeline_parallel_rank
    try:
        gpc.set_virtual_pipeline_parallel_rank(rank)
        yield
    finally:
        gpc.set_virtual_pipeline_parallel_rank(prev_rank)


def disposable(func: Callable) -> Callable:
    executed = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)

    return wrapper