#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable

from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
from .cuda import get_current_device

import weakref


def copy_to_device(obj, device):
    """Recursively copy every tensor contained in ``obj`` onto ``device``.

    Tensors are copied with ``.to(device)`` and then detached. The copy's
    ``requires_grad`` flag is set explicitly to match the source tensor.

    Containers are rebuilt with their elements copied recursively:
    - a ``list`` becomes a new ``list``;
    - a ``tuple`` becomes a new ``tuple``;
    - a ``dict`` becomes a new ``dict`` with the same keys.

    Any other object is returned unchanged.
    """
    if torch.is_tensor(obj):
        # Inside a no_grad context the moved tensor would report requires_grad=False,
        # so detach it and copy the flag over from the source explicitly.
        moved = obj.to(device).detach()
        moved.requires_grad = obj.requires_grad
        return moved
    if isinstance(obj, (list, tuple)):
        converted = [copy_to_device(item, device) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)
    if isinstance(obj, dict):
        return {key: copy_to_device(value, device) for key, value in obj.items()}
    return obj


class CheckpointFunction(torch.autograd.Function):
    """Reentrant activation checkpointing implemented as a custom autograd Function.

    ``forward`` runs ``run_function`` under ``torch.no_grad()``, so no intermediate
    activations are recorded. It keeps only:
    - the inputs, optionally as CPU copies;
    - the CPU RNG state;
    - a copy of the ColossalAI seed states;
    - the current mode from ``get_current_mode()``.

    ``backward`` restores those states, re-runs ``run_function`` with gradients
    enabled, back-propagates the incoming gradients through the recomputed graph,
    and returns the gradients of the inputs.

    Invoked as ``CheckpointFunction.apply(run_function, activation_offload, *args)``.
    """

    @staticmethod
    def forward(ctx, run_function, activation_offload=False, *args):
        """Run ``run_function(*args)`` without building an autograd graph.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            run_function: Callable invoked as ``run_function(*inputs)``.
            activation_offload: If True, tensor inputs are kept as detached CPU
                copies until ``backward``. They are not stored via
                ``save_for_backward``.
            *args: Inputs to ``run_function``. Non-tensor values are stored as-is;
                tensors are stored for recomputation.

        Returns:
            Whatever ``run_function`` returns.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.activation_offload = activation_offload
        # Device that offloaded tensors are copied back to in backward
        # (presumably the current accelerator device; see get_current_device).
        ctx.device = get_current_device()

        # preserve rng states: CPU RNG, plus a copy of the ColossalAI seed states and mode
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        ctx.fwd_seed_states = get_states(copy=True)
        ctx.fwd_current_mode = get_current_mode()

        # Remember whether autocast was active so the recomputation can match it.
        if hasattr(torch, 'is_autocast_enabled'):
            ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        else:
            ctx.had_autocast_in_fwd = False

        # With offload enabled, run_function sees detached copies on ctx.device
        # rather than the original arguments.
        if activation_offload:
            inputs_cuda = copy_to_device(args, ctx.device)
        else:
            inputs_cuda = args

        with torch.no_grad():
            outputs = run_function(*inputs_cuda)
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                if activation_offload:
                    tensor_inputs.append(copy_to_device(arg, 'cpu'))
                else:
                    tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # Offloaded CPU copies are kept as a plain ctx attribute;
        # otherwise the original tensors go through save_for_backward.
        if activation_offload:
            ctx.tensor_inputs = tensor_inputs
        else:
            ctx.save_for_backward(*tensor_inputs)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and back-propagate ``args`` through it.

        Args:
            ctx: Autograd context populated by ``forward``.
            *args: Gradients w.r.t. each output of ``forward``, in output order.

        Returns:
            ``(None, None)`` for ``run_function`` and ``activation_offload``,
            followed by one entry per forward input. Each entry is the ``.grad``
            of the recomputed (detached) tensor input, or None for non-tensor
            inputs.

        Raises:
            RuntimeError: If checkpointing is invalid in the current backward call
                (e.g. ``torch.autograd.grad``), or if no recomputed output
                requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
                               "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices

        if ctx.activation_offload:
            tensors = ctx.tensor_inputs
        else:
            tensors = ctx.saved_tensors

        # store the current states
        bwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        bwd_seed_states = get_states(copy=True)
        bwd_current_mode = get_current_mode()

        # set the states to what it used to be
        # NOTE(review): these states are swapped back only after a successful
        # recompute (no try/finally), so an exception in run_function leaves
        # the forward-time states active.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        for parallel_mode, state in ctx.fwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(ctx.fwd_current_mode)
        # Bring offloaded CPU copies back to the forward-time device.
        if ctx.activation_offload:
            tensors = copy_to_device(tensors, ctx.device)

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]
        # Detach so the recomputed graph starts at fresh leaves whose .grad can be read below.
        detached_inputs = detach_variable(tuple(inputs))
        if ctx.had_autocast_in_fwd:
            with torch.enable_grad(), torch.cuda.amp.autocast():
                outputs = ctx.run_function(*detached_inputs)
        else:
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # recover the rng states
        torch.set_rng_state(bwd_cpu_rng_state)
        for parallel_mode, state in bwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(bwd_current_mode)

        # run backward() with only tensor that requires grad
        # args[i] is the incoming gradient for outputs[i].
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError("none of output has requires_grad=True,"
                               " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
        # Leading (None, None) correspond to the run_function and activation_offload arguments.
        return (None, None) + grads


def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
    """Checkpoint a computation while preserving the RNG states.

    Adapted from PyTorch's ``torch.utils.checkpoint``.

    Args:
        function: The forward-pass function to checkpoint. It is called as
            ``function(*args)`` and should know how to handle the input tuple.
        activation_offload: Whether the checkpointed tensor inputs should be
            offloaded to CPU.
        args (list): Positional arguments passed to ``function``.
        use_reentrant: If True (default), run through the autograd Function
            ``CheckpointFunction``. If False, use the saved-tensor-hook based
            implementation, which may give users more flexibility in how they
            define their checkpointed function.

    Returns:
        Output of running ``function`` with the provided args.
    """
    if not use_reentrant:
        return _checkpoint_without_reentrant(function, activation_offload, *args)
    return CheckpointFunction.apply(function, activation_offload, *args)


def _checkpoint_without_reentrant(function, activation_offload=False, *args):
    """Checkpoint ``function`` using saved-tensor hooks instead of a custom autograd Function.

    Forward pass:
    - ``function(*args)`` runs normally;
    - every tensor autograd would save is replaced by a lightweight ``Holder``
      placeholder, so the activations themselves are not retained.

    Backward pass, the first time autograd unpacks a placeholder:
    - ``function`` is rerun on the inputs with the forward-time RNG/seed states
      and autocast setting;
    - every tensor saved during that rerun is stored (detached) against its
      placeholder;
    - the RNG/seed states that were active before the recomputation are restored
      afterwards.

    Args:
        function: The forward function to checkpoint; called as ``function(*args)``.
        activation_offload: If True, the tensor inputs held for recomputation are
            replaced by detached CPU copies after the forward pass. They are copied
            back to the current device right before recomputation.
        *args: Inputs passed to ``function``.

    Returns:
        The output of ``function(*args)``.
    """
    # store rng_state
    fwd_cpu_state = torch.get_rng_state()
    sync_states()
    fwd_seed_states = get_states(copy=True)
    fwd_current_mode = get_current_mode()

    # check if use autocast
    if hasattr(torch, 'is_autocast_enabled'):
        has_autocast_in_fwd = torch.is_autocast_enabled()
    else:
        has_autocast_in_fwd = False

    # device the offloaded inputs are reloaded onto before recomputation
    device = get_current_device() if activation_offload else None

    # using WeakKeyDictionary to store all the activation the first time we call unpack
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    # class for weakref.ref
    class Holder():
        pass

    # return a Holder object for later unpack process
    def pack(x):
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def recompute():
        """Rerun ``function`` under forward-time RNG states, filling ``storage``."""
        unpack_counter = 0

        def inner_pack(inner):
            nonlocal unpack_counter
            unpack_counter += 1

            # Dereference the weakref once so the holder cannot disappear between
            # the liveness check and the store.
            holder = weak_holder_list[unpack_counter - 1]()
            # If the holder went out of scope, the SavedVariable is dead and so
            # the value will never be read from the storage. Skip filling it.
            if holder is None:
                return

            # Use detach here to ensure we don't keep the temporary autograd
            # graph created during the second forward
            storage[holder] = inner.detach()
            return

        def inner_unpack(packed):
            raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

        # remember the current (backward-time) states so they can be restored afterwards,
        # mirroring CheckpointFunction.backward
        bwd_cpu_state = torch.get_rng_state()
        sync_states()
        bwd_seed_states = get_states(copy=True)
        bwd_current_mode = get_current_mode()

        # restore forward-time rng state
        torch.set_rng_state(fwd_cpu_state)
        for parallel_mode, state in fwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(fwd_current_mode)

        try:
            # reload offloaded tensor inputs onto the compute device
            if activation_offload:
                recompute_args = tuple(
                    copy_to_device(arg, device) if torch.is_tensor(arg) else arg for arg in args)
            else:
                recompute_args = args

            # rerun forward, the inner_pack will store all the activations in storage
            if has_autocast_in_fwd:
                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    function(*recompute_args)
            else:
                with torch.enable_grad(), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    function(*recompute_args)
        finally:
            # recover the backward-time rng states even if recomputation fails
            torch.set_rng_state(bwd_cpu_state)
            for parallel_mode, state in bwd_seed_states.items():
                set_seed_states(parallel_mode, state)
            set_mode(bwd_current_mode)

    # unpack hook: recompute all activations the first time any of them is needed
    def unpack(x):
        if len(storage) == 0:
            recompute()

        if x not in storage:
            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                               " recomputation being triggered in between, this is not currently supported. Please"
                               " open an issue with details on your use case so that we can prioritize adding this.")

        return storage[x]

    # run function with pack and unpack as saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args)

    # Offload activation if needed. This is done outside the hooks context so the copies
    # are never packed. Rebinding ``args`` updates the variable captured by ``recompute``,
    # so only the CPU copies are kept alive for the recomputation.
    if activation_offload:
        args = tuple(copy_to_device(arg, "cpu") if torch.is_tensor(arg) else arg for arg in args)

    return output