import pytest
from torch import nn
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext
import gc


class SimpleNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)
        self.proj3 = nn.Linear(4, 4)
        self.proj4 = nn.Linear(4, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.proj2(x)
        x = self.proj3(x)
        x = self.proj4(x)
        return x


def _visit_graph(start_node):
    if start_node is None:
        return

    start_node.print()

    post_node_list = start_node.post_nodes
    for node in post_node_list:
        _visit_graph(node)


def _get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
        except Exception as e:
            print('A trivial exception occured: {}'.format(e))


def _count_tensors():
    cnt = 0
    for t in _get_tensors():
        cnt += 1
    return cnt


def count_tensors(use_colossal):
    model = SimpleNet()

    model.eval()
    with torch.no_grad():
        if use_colossal:
            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
            graph_ctx = GraphContext()
            with graph_ctx:
                output = model(colo_input)
            output = model(colo_input)
            ret = _count_tensors()

            _visit_graph(graph_ctx.graph_nodes[0])

            del graph_ctx
            return ret
        else:
            input_t = torch.randn(4)
            output = model(input_t)
            output = model(input_t)
            return _count_tensors()


@pytest.mark.skip
# FIXME(ver217)
def test_check_activation_tensors():
    assert count_tensors(False) == count_tensors(True)


if __name__ == "__main__":
    count_tensors(True)