#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Tuple, Union, Any
import pickle
import io

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d as c10d
from torch.distributed import ProcessGroupNCCL

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

# Anything accepted as a tensor shape. ``Tuple[int, ...]`` is a tuple of ints of
# any length; the previous ``Tuple[int]`` meant a 1-tuple only.
TensorShape = Union[torch.Size, List[int], Tuple[int, ...]]
# Registry of two-rank process groups keyed by ``(lower_rank, higher_rank)``;
# populated lazily by ``init_process_group`` on first lookup.
_pg_manager = {}
# Unpickler class used to rebuild objects received from other ranks.
_unpickler = pickle.Unpickler


def init_process_group():
    """Create a two-rank process group for every pair of adjacent pipeline stages.

    For a pipeline of ``P`` stages this stores ``P - 1`` groups in ``_pg_manager``
    under the keys ``(0, 1), (1, 2), ..., (P - 2, P - 1)``. Calling it again
    creates new groups and overwrites the existing entries.

    NOTE(review): the ranks passed to ``dist.new_group`` are the indices
    ``0 .. P-1`` derived from the pipeline world size, whereas ``new_group``
    takes global ranks (and must be entered by every process in the job). This
    only lines up when the pipeline covers the whole world in rank order —
    confirm before combining pipeline with other parallel modes.

    Returns:
        None
    """
    num_stages = gpc.get_world_size(ParallelMode.PIPELINE)
    for lower in range(num_stages - 1):
        pair = (lower, lower + 1)
        _pg_manager[pair] = dist.new_group(list(pair))


def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGroupNCCL:
    """Look up the process group made of exactly two ranks.

    The registry is filled on first use by calling :func:`init_process_group`.
    Entries are keyed ``(lower, higher)``, so the order of the two arguments
    does not matter.

    Args:
        first_rank (int): one rank of the pair.
        second_rank (int): the other rank of the pair.

    Returns:
        :class:`ProcessGroupNCCL`: the group containing both ranks.

    Raises:
        KeyError: if no group is registered for the pair (e.g. the ranks are
            equal or not adjacent).
    """
    if not _pg_manager:
        init_process_group()
    low, high = sorted((first_rank, second_rank))
    return _pg_manager[(low, high)]


def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: Union[int, torch.Tensor]) -> object:
    """Unpickle an object from a byte tensor, retargeting CUDA device indices to the current device.

    A CUDA tensor pickled on the sender carries its device as a location string
    such as ``cuda:1``. Before unpickling, every ``cuda:<d>`` with a single-digit
    index is rewritten to ``cuda:<current device>``, so tensors are rebuilt on
    this rank's GPU rather than on the sender's device index.

    The rewrite swaps one digit byte for another, so the buffer length never
    changes. Pickle stores explicit lengths for strings and byte blobs, so
    inserting or removing bytes would corrupt the stream. As a consequence,
    only single-digit indices are rewritten. A multi-digit source index
    (``cuda:12``) or a current device index >= 10 leaves the bytes untouched.

    Args:
        tensor (:class:`torch.Tensor`): CPU tensor whose raw bytes hold the pickled
            object (``.numpy()`` requires a CPU tensor).
        tensor_size (int or 0-dim integer :class:`torch.Tensor`): number of leading
            bytes of ``tensor`` that belong to the pickle.

    Returns:
        Any: the unpickled object.
    """
    buf = tensor.numpy().tobytes()[:tensor_size]

    marker = b'cuda:'
    if marker in buf:
        device_index = torch.cuda.current_device()
        # A second digit cannot be inserted without changing the length.
        if 0 <= device_index < 10:
            patched = bytearray(buf)
            pos = patched.find(marker)
            while pos != -1:
                digit_pos = pos + len(marker)
                # Patch only a lone digit: "cuda:1" yes, "cuda:12" no.
                is_digit = patched[digit_pos:digit_pos + 1].isdigit()
                next_is_digit = patched[digit_pos + 1:digit_pos + 2].isdigit()
                if is_digit and not next_is_digit:
                    patched[digit_pos] = ord('0') + device_index
                # Keep scanning: an object may contain several CUDA tensors.
                pos = patched.find(marker, digit_pos)
            buf = bytes(patched)

    # NOTE: unpickling can execute arbitrary code; this assumes the bytes come
    # from trusted peer ranks of the same job.
    return _unpickler(io.BytesIO(buf)).load()


def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None):
    """Broadcast picklable objects from ``src`` to ``dst`` over their two-rank group.

    This is a modified version of ``broadcast_object_list`` in ``torch.distributed``.
    The differences are:

    * Objects are unpickled with :func:`_cuda_safe_tensor_to_object`, which
      rewrites CUDA device indices in the pickled bytes.
    * A received object that is a tensor and whose device index differs from
      ``torch.cuda.current_device()`` is moved there with ``.cuda()``.

    If this rank's pipeline-local rank equals ``src``, the objects in
    ``object_list`` are serialized and sent. Otherwise ``object_list`` is
    overwritten in place with the objects received from ``src``. It must
    already have the same length as the sender's list.

    NOTE(review): ``src`` is compared against the pipeline-local rank, but it is
    also passed as ``src`` to ``c10d.broadcast``, which takes a global rank.
    These agree only when pipeline-local rank == global rank; confirm for mixed
    parallel setups.

    NOTE(review): on ``src``, an empty ``object_list`` fails at the
    ``zip(*[...])`` unpacking (ValueError); callers must send at least one object.

    Args:
        object_list (List[Any]): objects to send (on ``src``) or a pre-sized list
            to fill (on the receiver).
        src (int): rank that sends.
        dst (int): rank that receives; only used to select the pair group.
        device (:class:`torch.device`, optional): device used for the broadcast
            buffers. It must be a ``torch.device`` (``.type`` is read), and must be
            CUDA for the NCCL backend. If omitted, the current CUDA device is used
            under NCCL and the CPU otherwise. Buffers are only moved to it under
            NCCL.
    """
    group = _acquire_pair_group_handle(src, dst)

    # This rank is not a member of the pair group: nothing to do.
    if c10d._rank_not_in_group(group):
        c10d._warn_not_in_group("broadcast_object_list")
        return

    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    # Serialize object_list elements to tensors on src rank.
    if local_rank == src:
        tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
        # Receiver: placeholder for one byte-length per expected object.
        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)

    is_nccl_backend = c10d._check_for_nccl_backend(group)
    current_device = None

    if device is not None:
        if is_nccl_backend and device.type != "cuda":
            raise ValueError("device type must be cuda for nccl backend")
        current_device = device
    else:
        current_device = torch.device("cpu")
        if is_nccl_backend:
            current_device = torch.device("cuda", torch.cuda.current_device())
    # Only NCCL needs the buffers on the chosen device; other backends keep them where they are.
    if is_nccl_backend:
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)

    # Concatenate and broadcast serialized object tensors
    if local_rank == src:
        object_tensor = torch.cat(tensor_list)
    else:
        # Receiver: one flat uint8 buffer large enough for all serialized objects.
        object_tensor = torch.empty(    # type: ignore[call-overload]
            torch.sum(object_sizes_tensor).item(),    # type: ignore[arg-type]
            dtype=torch.uint8,
        )

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)

    c10d.broadcast(object_tensor, src=src, group=group, async_op=False)

    # Deserialize objects using their stored sizes.
    offset = 0

    if local_rank != src:
        for i, obj_size in enumerate(object_sizes_tensor):
            # Slice this object's bytes out of the flat buffer.
            obj_view = object_tensor[offset:offset + obj_size]
            obj_view = obj_view.type(torch.uint8)
            # The unpickle helper reads bytes via .numpy(), which needs a CPU tensor.
            if obj_view.device != torch.device("cpu"):
                obj_view = obj_view.cpu()
            offset += obj_size
            # unpickle
            unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

            # Device mismatch: move top-level tensors to the current CUDA device.
            # NOTE(review): a CPU tensor has device.index None, so it is also moved to CUDA here.
            if isinstance(unpickle_object,
                          torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
                unpickle_object = unpickle_object.cuda()

            object_list[i] = unpickle_object


def _send_object(object: Any, dst: int) -> None:
    """Send a picklable object to rank ``dst``; the counterpart of :func:`_recv_object`.

    Two broadcasts happen over the pair group of this rank and ``dst``. The
    first carries the number of items and the second carries the items.

    * A ``list`` or ``tuple`` is sent item by item. The receiver gets a ``list``,
      or the bare item if there is exactly one.
    * Any other object (tensor, dict, str, scalar, ...) is wrapped in a
      one-element list, so the receiver gets the object back as-is.

    An empty ``list``/``tuple`` is not supported by :func:`_broadcast_object_list`.

    Args:
        object (Any): object to send. (The name shadows the builtin; it is kept
            for backward compatibility.)
        dst (int): rank of the receiver.

    Returns:
        None
    """
    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Wrap everything that is not already a list/tuple. Iterating it item by
    # item would send only a dict's keys or a str's characters, and len() fails
    # on scalars.
    if isinstance(object, (list, tuple)):
        payload = object
    else:
        payload = [object]

    # Broadcast the item count first so the receiver can size its buffer.
    # TODO: fold the count into a single _broadcast_object_list call.
    _broadcast_object_list([len(payload)], local_rank, dst)
    _broadcast_object_list(payload, local_rank, dst)


def _recv_object(src: int) -> Any:
    """Receive an object sent by :func:`_send_object` on rank ``src``.

    First the item count is received, then that many items.

    Args:
        src (int): rank of the sender.

    Returns:
        Any: the single received item when the count is 1, otherwise the list
        of received items.
    """
    my_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Step 1: learn how many items the sender will broadcast.
    count_holder = [0]
    _broadcast_object_list(count_holder, src, my_rank)
    num_items = count_holder[0]

    # Step 2: receive the items into a buffer of exactly that size.
    received = [None] * num_items
    _broadcast_object_list(received, src, my_rank)

    # A lone item is handed back unwrapped.
    return received[0] if num_items == 1 else received


def recv_forward(prev_rank: int = None) -> Any:
    """Receive the previous stage's forward output to use as this stage's input.

    Args:
        prev_rank (int, optional): rank to receive from. When ``None``,
            ``gpc.get_prev_global_rank(ParallelMode.PIPELINE)`` is used.

    Returns:
        Any: the received object (a single object or a list), or ``None`` on
        the first pipeline stage, which has no predecessor.
    """
    if gpc.is_pipeline_first_stage():
        return None
    src = gpc.get_prev_global_rank(ParallelMode.PIPELINE) if prev_rank is None else prev_rank
    return _recv_object(src)


def recv_backward(next_rank: int = None) -> Any:
    """Receive the gradient sent back by the next pipeline stage.

    Args:
        next_rank (int, optional): rank to receive from. When ``None``,
            ``gpc.get_next_global_rank(ParallelMode.PIPELINE)`` is used.

    Returns:
        Any: the received gradient object (a single object or a list), or
        ``None`` on the last pipeline stage, which has no successor.
    """
    if gpc.is_pipeline_last_stage():
        return None
    src = next_rank if next_rank is not None else gpc.get_next_global_rank(ParallelMode.PIPELINE)
    return _recv_object(src)


def send_forward(output_object: Any, next_rank: int = None) -> None:
    """Send this stage's forward output to the next pipeline stage.

    Does nothing on the last pipeline stage.

    Args:
        output_object (Any): object to send.
        next_rank (int, optional): rank of the receiver. When ``None``,
            ``gpc.get_next_global_rank(ParallelMode.PIPELINE)`` is used.
    """
    if gpc.is_pipeline_last_stage():
        return
    if next_rank is None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
    _send_object(output_object, next_rank)


def send_backward(input_object: Any, prev_rank: int = None) -> None:
    """Send a gradient object back to the previous pipeline stage.

    Does nothing on the first pipeline stage.

    Args:
        input_object (Any): object to send (the gradient for the previous stage).
        prev_rank (int, optional): rank of the receiver. When ``None``,
            ``gpc.get_prev_global_rank(ParallelMode.PIPELINE)`` is used.
    """
    if gpc.is_pipeline_first_stage():
        return
    if prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
    _send_object(input_object, prev_rank)