import pytest
import torch

from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer


@pytest.mark.dist
def test_stateful_tensor_container():
    st1 = StatefulTensor(torch.randn(1, device='cuda'))
    st2 = StatefulTensor(torch.randn(2, device='cuda'))
    st3 = StatefulTensor(torch.randn(3, device='cuda'))
    stateful_tensor_list = [st1, st2, st3]
    step_list = [st1, st2, st3, st3, st2, st1]

    compute_step_dict = dict()
    compute_step_dict[st1] = [0, 5]
    compute_step_dict[st2] = [1, 4]
    compute_step_dict[st3] = [2, 3]

    def run_queue_test():
        # test queue container
        queue_container = QueueSTContainer(compute_step_dict, 6)
        queue_container.create(stateful_tensor_list)

        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = queue_container.pop()
            st_out.move_to(torch.device('cpu'))

            res_list.append(st_out.payload.size(0))

            stateful_tensor.move_to(torch.device('cuda'))
            queue_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [2, 3, 1, 2, 3, 2]

    run_queue_test()

    def run_heap_test():
        # test heap container
        st1.move_to(torch.device('cuda'))
        st2.move_to(torch.device('cuda'))
        st3.move_to(torch.device('cuda'))

        heap_container = HeapSTContainer(compute_step_dict, 6)
        heap_container.create(stateful_tensor_list)

        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = heap_container.pop()

            if st_out is not None:
                res_list.append(st_out.payload.size(0))
                st_out.move_to(torch.device('cpu'))

            stateful_tensor.move_to(torch.device('cuda'))
            heap_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [3, 1, 2, 3, 2]

    run_heap_test()


if __name__ == '__main__':
    test_stateful_tensor_container()