from typing import List

import torch
from torch.fx.node import Node

from .region import Region
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd


class SynPreFwdPostBwdOP(torch.autograd.Function):
    """
    Pass-through autograd op that frees / uploads region data inline.

    The input tensor and the incoming gradient are returned untouched. The
    work happens as a side effect on regions taken from
    ``GlobalRuntimeInfo().region_list``; no dedicated stream or event handling
    is done in this class.

    Args:
        input_: input tensor, returned as is.
        fwd_info: dict read during the forward pass. ``'d2h_rid'`` is the
            index of a region on which ``free_cuda_data()`` is called;
            ``'h2d_rid'`` is the index of a region on which
            ``move_param_to_cuda()`` is called. Both keys are optional.
        bwd_info: dict read during the backward pass. ``'h2d_rid'`` is the
            index of a region on which ``move_param_to_cuda()`` is called.
            The key is optional.
    """

    @staticmethod
    def _lookup_region(rid):
        # Region ids are used as positions in the global region list.
        region = GlobalRuntimeInfo().region_list[rid]
        assert isinstance(region, Region)
        return region

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        # Keep the backward instructions around; they are only read in backward().
        ctx.bwd_info = bwd_info

        # The free is issued before the upload.
        offload_rid = fwd_info.get('d2h_rid')
        if offload_rid is not None:
            SynPreFwdPostBwdOP._lookup_region(offload_rid).free_cuda_data()

        upload_rid = fwd_info.get('h2d_rid')
        if upload_rid is not None:
            SynPreFwdPostBwdOP._lookup_region(upload_rid).move_param_to_cuda()

        return input_

    @staticmethod
    def backward(ctx, grad_output):
        upload_rid = ctx.bwd_info.get('h2d_rid')
        if upload_rid is not None:
            SynPreFwdPostBwdOP._lookup_region(upload_rid).move_param_to_cuda()

        # One slot per forward() argument; the two info dicts receive no gradient.
        return grad_output, None, None


class AsynPreFwdPostBwdOP(torch.autograd.Function):
    """
    Pass-through autograd op that launches parameter uploads on a side stream.

    The input tensor and the incoming gradient are returned untouched. As a
    side effect, uploads are issued on ``GlobalRuntimeInfo().h2d_stream`` and
    tracked with CUDA events stored per region index in
    ``fwd_prefetch_event_map`` / ``bwd_prefetch_event_map``.

    Args:
        input_: input tensor, returned as is.
        fwd_info: dict read during the forward pass. ``'sync_rid'`` is the
            index of a region whose recorded forward prefetch event (if any)
            is waited on; ``'h2d_rid'`` is the index of a region whose
            prefetch is started. Both keys are optional and no other key is
            read here.
        bwd_info: dict read during the backward pass, with the same two
            optional keys. For ``'sync_rid'``, a region without a recorded
            backward prefetch event gets ``move_param_to_cuda()`` called
            directly instead of a wait.
    """

    @staticmethod
    def _prefetch_on_h2d_stream(rid):
        # Issue the upload of region `rid` on the h2d stream and return the
        # event recorded on that stream right after it.
        region = GlobalRuntimeInfo().region_list[rid]
        assert isinstance(region, Region)

        # Captured before switching streams, so this is the caller's stream.
        compute_stream = torch.cuda.current_stream()
        copy_stream = GlobalRuntimeInfo().h2d_stream
        with torch.cuda.stream(copy_stream):
            # The copy stream first waits for work already queued by the caller.
            copy_stream.wait_stream(compute_stream)
            region.move_param_to_cuda()

        done_event = torch.cuda.Event()
        done_event.record(copy_stream)
        return done_event

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        # Keep the backward instructions around; they are only read in backward().
        ctx.bwd_info = bwd_info

        wait_rid = fwd_info.get('sync_rid')
        if wait_rid is not None:
            pending = GlobalRuntimeInfo().fwd_prefetch_event_map.get(wait_rid)
            # Nothing is done when no forward prefetch was recorded for this region.
            if pending:
                pending.wait()

        fetch_rid = fwd_info.get('h2d_rid')
        if fetch_rid is not None:
            # The helper runs first; its event is then published under the region index.
            GlobalRuntimeInfo().fwd_prefetch_event_map[fetch_rid] = \
                AsynPreFwdPostBwdOP._prefetch_on_h2d_stream(fetch_rid)

        return input_

    @staticmethod
    def backward(ctx, grad_output):
        wait_rid = ctx.bwd_info.get('sync_rid')
        if wait_rid is not None:
            region = GlobalRuntimeInfo().region_list[wait_rid]
            assert isinstance(region, Region)
            pending = GlobalRuntimeInfo().bwd_prefetch_event_map.get(wait_rid)
            if not pending:
                # No backward prefetch was recorded: upload directly instead.
                region.move_param_to_cuda()
            else:
                pending.wait()

        fetch_rid = ctx.bwd_info.get('h2d_rid')
        if fetch_rid is not None:
            GlobalRuntimeInfo().bwd_prefetch_event_map[fetch_rid] = \
                AsynPreFwdPostBwdOP._prefetch_on_h2d_stream(fetch_rid)

        # One slot per forward() argument; the two info dicts receive no gradient.
        return grad_output, None, None


def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    '''
    Apply the synchronous op ``SynPreFwdPostBwdOP`` to a tensor.

    This is the callable that the graph passes in this file insert as a
    ``call_function`` node.

    Argument:
        tensor(torch.Tensor): input tensor, forwarded to the op.
        fwd_info(dict): region indices the op acts on during the forward pass
            (keys ``'h2d_rid'`` and ``'d2h_rid'``).
        bwd_info(dict): region indices the op acts on during the backward pass
            (key ``'h2d_rid'``).

    Returns:
        Whatever ``SynPreFwdPostBwdOP.apply`` returns for ``tensor``.
    '''
    # The op is applied with __torch_function__ handling switched off.
    with torch._C.DisableTorchFunction():
        return SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)


def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    '''
    Apply the asynchronous op ``AsynPreFwdPostBwdOP`` to a tensor.

    This is the callable that the asynchronous graph pass in this file inserts
    as a ``call_function`` node.

    Argument:
        tensor(torch.Tensor): input tensor, forwarded to the op.
        fwd_info(dict): region indices the op acts on during the forward pass
            (keys ``'sync_rid'`` and ``'h2d_rid'``).
        bwd_info(dict): region indices the op acts on during the backward pass
            (keys ``'sync_rid'`` and ``'h2d_rid'``).

    Returns:
        Whatever ``AsynPreFwdPostBwdOP.apply`` returns for ``tensor``.
    '''
    # The op is applied with __torch_function__ handling switched off.
    with torch._C.DisableTorchFunction():
        return AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)


def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
    """
    Rewire users of ``orig_node`` so that they consume ``inserted_node`` instead.

    Every use of ``orig_node`` in a user is replaced: repeated positional
    uses, keyword uses (whatever the keyword is called), and uses nested
    inside list/tuple/dict arguments.

    Args:
        orig_node: node whose uses are redirected.
        inserted_node: node that takes its place. It is never rewired itself,
            because it normally takes ``orig_node`` as its own input.
        rep_user_nodes: optional subset of users to rewire. When ``None``,
            all current users of ``orig_node`` are rewired.
    """
    # Snapshot the users: rewiring mutates ``orig_node.users`` while we iterate.
    targets = list(orig_node.users if rep_user_nodes is None else rep_user_nodes)
    for user in targets:
        if user == inserted_node:
            continue
        # Maps over args and kwargs recursively; a no-op for nodes that do not
        # actually use ``orig_node``.
        user.replace_input_with(orig_node, inserted_node)


def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    Insert synchronous upload/offload action nodes into the graph of ``gm``.

    For each region, an action node calling
    ``convert_fwd_upload_bwd_offload_to_action`` is inserted after an anchor
    node whenever there is something to do. The anchor is the first graph node
    for the first region and the last node of the preceding region afterwards.
    All users of the anchor are rewired to the action node.

    Args:
        gm: graph module whose graph is modified in place.
        region_list: regions in order; they are indexed by position and by
            ``shared_rid``.

    Returns:
        The same ``gm`` object.
    """
    graph = gm.graph
    anchor = tuple(graph.nodes)[0]

    for idx, region in enumerate(region_list):
        fwd_info = {}
        bwd_info = {}

        # Forward: upload this region's parameters if its shared region requires it.
        if requires_upload_p_in_fwd(region_list[region.shared_rid]):
            fwd_info['h2d_rid'] = region.r_id

        if idx > 0:
            prev_region = region_list[idx - 1]
            if prev_region.need_offload:
                # Forward: the preceding region is freed once this region starts.
                fwd_info['d2h_rid'] = idx - 1
                # Backward: the preceding region is uploaded again at this point.
                bwd_info['h2d_rid'] = prev_region.r_id

        if fwd_info or bwd_info:
            with graph.inserting_after(anchor):
                action_node = graph.create_node(
                    'call_function',
                    convert_fwd_upload_bwd_offload_to_action,
                    args=(anchor, fwd_info, bwd_info),
                )
            replace_node_users(anchor, action_node)

        # The next region's action node hangs off the end of this region.
        anchor = region.nodes[-1]

    return gm


def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    Insert asynchronous prefetch/offload action nodes into the graph of ``gm``.

    Three kinds of nodes are inserted:

    1. A synchronous upload node right after the first graph node, for the
       first region whose ``param_size`` is truthy.
    2. Per region, a node calling ``convert_fwd_prefetch_bwd_offload_to_action``
       whenever its forward or backward info dict is non-empty.
    3. A trailing node after the last region when that region has a
       ``bwd_prefetch_region``.

    Each inserted node takes over all users of the node it is inserted after.

    Args:
        gm: graph module whose graph is modified in place.
        region_list: regions in order; they are indexed by position and by
            ``shared_rid``.

    Returns:
        The same ``gm`` object.

    Raises:
        IndexError: if no region in ``region_list`` has a truthy ``param_size``.
    """
    graph = gm.graph

    # 1. Upload the parameters of the first region that owns any.
    anchor = tuple(graph.nodes)[0]
    first_with_params = [r for r in region_list if r.param_size][0]
    upload_info = {"h2d_rid": first_with_params.r_id}
    with graph.inserting_after(anchor):
        upload_node = graph.create_node(
            'call_function',
            convert_fwd_upload_bwd_offload_to_action,
            args=(anchor, upload_info, {}),
        )
    replace_node_users(anchor, upload_node)
    anchor = upload_node

    # 2. One optional action node per region.
    for idx, region in enumerate(region_list):
        fwd_info = {}
        bwd_info = {}

        # Forward: wait for this region's own parameters ...
        if region.param_size:
            fwd_info['sync_rid'] = region.r_id
        # ... and start prefetching the region scheduled after it, if any.
        upcoming = region.fwd_prefetch_region
        if upcoming and requires_upload_p_in_fwd(region_list[upcoming.shared_rid]):
            fwd_info['h2d_rid'] = upcoming.r_id

        if idx > 0:
            prev_region = region_list[idx - 1]
            if prev_region.need_offload:
                # Forward: mark the preceding region for offload.
                fwd_info['d2h_rid'] = idx - 1
                # Backward: the preceding region must be resident at this point.
                bwd_info['sync_rid'] = idx - 1
            if prev_region.bwd_prefetch_region:
                bwd_info['h2d_rid'] = prev_region.bwd_prefetch_region.r_id

        if fwd_info or bwd_info:
            with graph.inserting_after(anchor):
                action_node = graph.create_node(
                    'call_function',
                    convert_fwd_prefetch_bwd_offload_to_action,
                    args=(anchor, fwd_info, bwd_info),
                )
            replace_node_users(anchor, action_node)

        # The next region's action node hangs off the end of this region.
        anchor = region.nodes[-1]

    # 3. Backward prefetch requested by the last region.
    last_region = region_list[-1]
    if last_region.bwd_prefetch_region:
        tail_bwd_info = {'h2d_rid': last_region.bwd_prefetch_region.r_id}
        with graph.inserting_after(anchor):
            tail_node = graph.create_node(
                'call_function',
                convert_fwd_prefetch_bwd_offload_to_action,
                args=(anchor, {}, tail_bwd_info),
            )
        replace_node_users(anchor, tail_node)

    return gm