import heapq
import inspect
from collections import OrderedDict
from typing import List

import torch

from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.logging import get_dist_logger


def _binary_partition(weights: List, start: int, end: int):
    """Returns the binary partition position of `weights`, given the start
    position `st` and the end position `ed`.

    Args:
        weights (list): a prefix-sum list to be binary partitioned
        start (int): the start position of the binary partition
        end (int): the end position of the binary partition

    Returns:
        tuple: ``(start, pos, end)``, where ``pos`` is the position that splits the range most evenly
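
    Example (illustrative; the list is the prefix sum of [1, 2, 3, 4])::

        >>> _binary_partition([1, 3, 6, 10], 0, 4)
        (0, 3, 4)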
    """
    w_sum = weights[end - 1]
    prefix = 0
    if start > 0:
        w_sum -= weights[start - 1]
        prefix = weights[start - 1]
    minimum = float("inf")
    pos = start + 1  # defensive default in case the range below is empty
    for idx in range(start + 1, end):
        front = weights[idx - 1] - prefix
        diff = abs(w_sum - 2 * front)
        if diff < minimum:
            pos = idx
            minimum = diff

    return start, pos, end


def _heap_addition(weights: List, intervals: List, add_cnt: int):
    """ """

    def _heap_push(heap, st, ed):
        value = weights[ed - 1]
        if st > 0:
            value -= weights[st - 1]
        heapq.heappush(heap, (-value, st, ed))

    ret_intervals = []
    heap = []

    for st, ed in intervals:
        _heap_push(heap, st, ed)

    while add_cnt > 0:
        _, st, ed = heapq.heappop(heap)
        if ed - st == 1:
            ret_intervals.append((st, ed))
        else:
            l, m, r = _binary_partition(weights, st, ed)
            _heap_push(heap, l, m)
            _heap_push(heap, m, r)
            add_cnt -= 1

    while heap:
        _, st, ed = heapq.heappop(heap)
        ret_intervals.append((st, ed))

    ret_intervals.sort()
    return ret_intervals


def _calc_partitions(weights, value):
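    """Greedily split the prefix-sum list `weights` into consecutive blocks whose
    weight does not exceed `value`, returning the block count and the half-open intervals.

    Example (illustrative; the list is the prefix sum of [1, 2, 3, 4, 5, 6])::

        >>> _calc_partitions([1, 3, 6, 10, 15, 21], 11)
        (2, [(0, 4), (4, 6)])
    """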
    prev = 0
    prefix = 0
    num_block = 0
    intervals = []

    for idx, w in enumerate(weights):
        if w - prefix > value:
            intervals.append((prev, idx))
            prev = idx
            prefix = weights[idx - 1]
            num_block += 1

    intervals.append((prev, len(weights)))
    return num_block + 1, intervals


def _binary_search(weights, num):
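    """Binary-search the smallest per-block weight that yields at most `num` blocks.

    Zero weights are treated as 1, the list is turned into a prefix sum, and the
    resulting partition is refined with `_heap_addition` if fewer than `num`
    blocks were produced.

    Example (illustrative)::

        >>> _binary_search([1, 2, 3, 4, 5, 6], 2)
        [(0, 4), (4, 6)]
    """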
    length = len(weights)
    prefix = [1 if w == 0 else w for w in weights]
    for i in range(1, length):
        prefix[i] += prefix[i - 1]

    lower_bound = max(weights)
    upper_bound = prefix[length - 1]

    while upper_bound > lower_bound:
        mid = (upper_bound + lower_bound) // 2
        number, _ = _calc_partitions(prefix, mid)
        if number <= num:
            upper_bound = mid
        else:
            lower_bound = mid + 1

    num_block, intervals = _calc_partitions(prefix, upper_bound)
    if num_block < num:
        intervals = _heap_addition(prefix, intervals, num - num_block)

    return intervals


def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
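    """Evenly split `num_items` layers across `pipeline_parallel_size` stages, chunk by chunk.

    Example (illustrative)::

        >>> partition_uniform(4, 2, 1)
        [[(0, 2)], [(2, 4)]]
    """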
    assert (
        num_items % num_chunks == 0
    ), "Layer length should be divisible by the number of chunks; otherwise the 'parameter' partition method is recommended"

    logger = get_dist_logger()
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        if chunk_size == 0:
            logger.warning("Some nodes in Pipeline have no requests")

        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))

    return parts


def partition_balanced(weights, pipeline_parallel_size, num_chunks):
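    """Partition `weights` into pipeline stages so that per-stage weight sums are balanced.

    Falls back to `partition_uniform` when there are no more items than stage slots.

    Example (illustrative)::

        >>> partition_balanced([1, 2, 3, 4, 5, 6], 2, 1)
        [[(0, 4)], [(4, 6)]]
    """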
    num_total = pipeline_parallel_size * num_chunks
    num_items = len(weights)
    if num_items <= num_total:
        return partition_uniform(num_items, pipeline_parallel_size, num_chunks)

    intervals = _binary_search(weights, num_total)

    current = 0
    parts = [[] for _ in range(pipeline_parallel_size)]
    for inter in intervals:
        parts[current].append(inter)
        current = (current + 1) % pipeline_parallel_size

    return parts


def build_kwargs_for_module(function, input_tensor, kw_dict):
    """
    Generally, the first argument of module.forward is an input tensor coming from the previous layer.
    Therefore, we only keep the kwargs whose names appear after the positional input(s) in the signature.
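
    Example (illustrative; `fwd` is a hypothetical forward signature)::

        >>> def fwd(hidden_states, attention_mask=None, labels=None): ...
        >>> build_kwargs_for_module(fwd, torch.ones(1), {"attention_mask": 1, "foo": 2})
        {'attention_mask': 1}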
    """
    sig = inspect.signature(function)
    if input_tensor is None:
        kwargs_offset = 0
    elif isinstance(input_tensor, torch.Tensor):
        kwargs_offset = 1
    elif isinstance(input_tensor, (tuple, OrderedDict)):
        # Hugging Face models pass their own OrderedDict-based structures between
        # layers, so the tuple-only assertion was dropped in favour of this check.
        kwargs_offset = len(input_tensor)
    else:
        raise TypeError("input_tensor should be None, a torch.Tensor, a tuple, or an OrderedDict.")
    args_name_list = list(sig.parameters.keys())
    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
    if len(kw_dict) == 0:
        return None
    return kw_dict


def build_kwargs_for_function(function, kw_dict):
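    """Keep only the entries of `kw_dict` that match parameter names of `function`.

    Example (illustrative; `f` is a hypothetical callable)::

        >>> def f(x, attention_mask=None): ...
        >>> build_kwargs_for_function(f, {"attention_mask": 1, "labels": 2})
        {'attention_mask': 1}
    """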
    sig = inspect.signature(function)
    kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
    if len(kw_dict) == 0:
        return None
    return kw_dict


def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
    """
    We assume the callable object passed to the to_layer_list method serves one of two purposes:
        a. use the callable object to modify input tensor, such as \
            lambda x: torch.flatten(x, 1)
        b. use the callable object to modify kwargs value, such as \
            def foo(attention_mask=None):
                if attention_mask is not None:
                    batch_size = input_ids.shape[0]
                    attention_mask = attention_mask.view(batch_size, -1)
                return attention_mask
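
    Example (illustrative, purpose (a))::

        >>> exec_func_with_kwargs(lambda x: x + 1, None, torch.zeros(1), {})
        tensor([1.])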
    """

    if kw_dict is not None:
        rst = func(**kw_dict)
        if isinstance(rst, tuple):
            for i, k in enumerate(kw_dict.keys()):
                kwargs[k] = rst[i]
        else:
            for k in kw_dict.keys():
                kwargs[k] = rst
        return input_tensor
    if isinstance(input_tensor, tuple):
        assert len(input_tensor) > 0, "input_tensor should not be empty when kw_dict is None."
        sig = inspect.signature(func)
        func_args_num = len(sig.parameters)
        assert func_args_num <= len(
            input_tensor
        ), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}."
        if func_args_num < len(input_tensor):
            return func(*input_tensor[:func_args_num])
        else:
            return func(*input_tensor)
    assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a torch.Tensor or a tuple."
    return func(input_tensor)


def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
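    """Run the function(s) registered under `func_key` in `func_dict`, threading `input_tensor` through them.

    Example (illustrative; "front" is a hypothetical key)::

        >>> funcs = {"front": lambda x: x * 2}
        >>> exec_funcs_with_kwargs(funcs, "front", torch.ones(1), {})
        tensor([2.])
    """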
    assert func_key in func_dict, f"{func_key} is not in the function_dict."
    funcs_to_exec = func_dict[func_key]
    if isinstance(funcs_to_exec, list):
        for f in funcs_to_exec:
            f_kwargs = build_kwargs_for_function(f, kwargs)
            input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
    else:
        f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
        input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)

    return input_tensor


def call_module(module, args=None, kwargs=None):
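    """Call `module` with only as many positional args as its forward signature still needs after kwargs.

    Example (illustrative)::

        >>> linear = torch.nn.Linear(4, 2)
        >>> call_module(linear, args=(torch.randn(1, 4),)).shape
        torch.Size([1, 2])
    """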
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    if isinstance(module, CheckpointModule):
        forward_func = module._forward
    else:
        forward_func = module.forward
    sig = inspect.signature(forward_func)
    param_nums = len(sig.parameters)
    args_needed_nums = param_nums - len(kwargs)
    args_needed = args[:args_needed_nums]
    if isinstance(module, CheckpointModule):
        # Flatten kwargs into positional args for the checkpointed forward.
        return module(*args_needed, *kwargs.values())
    else:
        return module(*args_needed, **kwargs)


def customized_partition(exec_seq):
    """
    This function analyzes `exec_seq`, in which users insert the string 'SPLIT_NODE'
    as an annotation to mark the partition points.
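
    Example (illustrative; layer names are hypothetical)::

        >>> customized_partition(["embed", "layer0", "SPLIT_NODE", "layer1"])
        {0: [(0, 2)], 1: [(2, 3)]}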
    """
    customized_parts = {}
    start = 0
    stop = 0
    rank = 0
    for element in exec_seq:
        if isinstance(element, str):
            if element == "SPLIT_NODE":
                customized_parts[rank] = [(start, stop)]
                start = stop
                rank += 1
            else:
                stop += 1
    customized_parts[rank] = [(start, stop)]
    return customized_parts