#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
from colossalai.utils import checkpoint


def forward(x, weight):
    """Multiply ``x`` by ``weight`` and apply dropout to the product.

    The dropout call (``p=0.4``, always in training mode) is issued inside the
    ``seed(ParallelMode.DATA)`` context, so the random mask is drawn while that
    context is active.

    Args:
        x: left operand of the matrix product.
        weight: right operand of the matrix product.

    Returns:
        The dropped-out result of ``x @ weight``.
    """
    projected = x.matmul(weight)
    with seed(ParallelMode.DATA):
        return F.dropout(projected, p=0.4, training=True)


@pytest.mark.gpu
@pytest.mark.parametrize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload):
    """Verify that checkpointed and plain forward/backward passes agree.

    ``forward`` is run once directly and once through ``checkpoint`` on cloned
    inputs, with the CUDA RNG states of the GLOBAL and DATA modes rewound in
    between so both passes start from the same random state. The gradients of
    the input and of the weight must then be identical.

    Args:
        cpu_offload: forwarded as the second argument of ``checkpoint``;
            parametrized over ``True`` and ``False``.
    """
    try:
        add_seed(ParallelMode.GLOBAL, 1024)
        add_seed(ParallelMode.DATA, 1026)

        # Snapshot the CUDA RNG state of each mode so it can be restored
        # before the checkpointed run.
        set_mode(ParallelMode.GLOBAL)
        global_cuda_rng_state = torch.cuda.get_rng_state()
        set_mode(ParallelMode.DATA)
        data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
        set_mode(ParallelMode.GLOBAL)

        # normal
        # ``.cuda()`` returns a non-leaf tensor, so ``retain_grad()`` is needed
        # for ``.grad`` to be populated after ``backward()``.
        data = torch.rand(2, 2, requires_grad=True).cuda()
        data.retain_grad()
        weight = torch.rand(2, 4, requires_grad=True).cuda()
        weight.retain_grad()

        # Detached clones are leaf tensors; they accumulate ``.grad`` without
        # needing ``retain_grad()``.
        data_ = data.clone().detach()
        data_.requires_grad = True
        weight_ = weight.clone().detach()
        weight_.requires_grad = True

        out = forward(data, weight)
        loss = out.sum()
        loss.backward()

        # checkpoint
        # Rewind both RNG states so the checkpointed run sees the same
        # starting state as the plain run above.
        set_mode(ParallelMode.GLOBAL)
        torch.cuda.set_rng_state(global_cuda_rng_state)
        set_mode(ParallelMode.DATA)
        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
        set_mode(ParallelMode.GLOBAL)

        # NOTE: ``checkpoint`` is ``colossalai.utils.checkpoint``; the later
        # import in this file shadows ``torch.utils.checkpoint.checkpoint``.
        out = checkpoint(forward, cpu_offload, data_, weight_)
        loss = out.sum()
        loss.backward()

        assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
        assert torch.all(weight.grad == weight_.grad), 'Gradient of the weight does not match'
    finally:
        # The seed manager is a singleton: if the seeds are not reset, other
        # tests running in the same process cannot overwrite the seeds set
        # here and will fail. Doing this in ``finally`` keeps a failure in
        # this test from cascading into unrelated tests.
        reset_seeds()
        torch.cuda.empty_cache()