import queue
import heapq
from abc import ABC, abstractmethod
from typing import Optional, List, Dict
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState


def evict_check(st: StatefulTensor) -> bool:
    """Tell whether ``st`` is currently a valid eviction candidate.

    A stateful tensor qualifies when it is not in the COMPUTE state (i.e. not
    pinned for computation) and it still lives on a CUDA device (i.e. it has
    not been evicted already).

    Args:
        st: the stateful tensor to inspect.

    Returns:
        True if ``st`` may be evicted from CUDA, False otherwise.
    """
    # `and` short-circuits, so the device is only consulted for non-COMPUTE tensors.
    return st.state is not TensorState.COMPUTE and st.device.type == 'cuda'


# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
    """Abstract container of stateful tensors that are candidates for eviction from CUDA.

    A stateful tensor is a candidate when both of the following hold:

    1. it has not been evicted yet, meaning its device type is CUDA;
    2. it is not pinned in CUDA memory, meaning its state is not COMPUTE.

    A stateful tensor should be handed to the container when it changes from
    COMPUTE to a HOLD-like state. The eviction routine (``evict_tensors``) then
    pops tensors back out of it.

    To support an optimal eviction policy, callers may supply the compute step
    indices of every stateful tensor (``compute_step_dict``) together with the
    total number of steps (``total_step``). With that information a heap can
    maintain all potentially evictable stateful tensors and, on each pop, yield
    the tensor whose next use is furthest from the current computation step.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        # Per-tensor list of compute step indices, and the total number of steps.
        self.compute_step_dict, self.total_step = compute_step_dict, total_step

    @abstractmethod
    def empty(self) -> bool:
        """Return True if the container currently holds no entries."""

    @abstractmethod
    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """(Re)initialize the container from ``stateful_tensor_list``."""

    @abstractmethod
    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Add ``stateful_tensor`` to the container at computation step ``cur_step``."""

    @abstractmethod
    def pop(self) -> Optional[StatefulTensor]:
        """Remove and return the next evictable stateful tensor, or None if there is none."""


class QueueSTContainer(BaseSTContainer):
    """First-in-first-out stateful tensor container, used by the 'cpu' tensor placement policy.

    Candidates come out in the order they were inserted. Entries are not
    removed when a tensor stops being evictable; `pop` simply skips them.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # Created lazily by `create`; None until then.
        self.container = None

    def empty(self) -> bool:
        """Return True when no entries are queued. Requires `create` to have run."""
        assert self.container is not None
        return self.container.empty()

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Replace the container with a fresh queue holding ``stateful_tensor_list`` in order."""
        fifo = self.container = queue.SimpleQueue()
        for item in stateful_tensor_list:
            fifo.put(item)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Append ``stateful_tensor`` to the back of the queue; ``cur_step`` is unused."""
        self.container.put(stateful_tensor)

    def pop(self) -> Optional[StatefulTensor]:
        """Dequeue until an evictable tensor is found.

        Entries failing `evict_check` are consumed and discarded. Returns None
        once the queue has been drained without finding a candidate.
        """
        while not self.empty():
            candidate = self.container.get()
            if evict_check(candidate):
                return candidate
        return None


class HeapSTContainer(BaseSTContainer):
    """Heap type stateful tensor container. This is used in 'auto' tensor placement policy.

    It pops potentially evictable stateful tensors in descending order of their
    next compute step: the tensor whose next use lies furthest in the future is
    popped first. A tensor with no future compute step is treated as if it were
    next used at ``total_step``.

    Heap entries are ``(weight, insertion_id, stateful_tensor)`` triples, where
    ``weight`` is the negated next compute step (``heapq`` is a min-heap). The
    monotonically increasing ``insertion_id`` breaks ties between equal weights,
    so ``heapq`` never has to compare two ``StatefulTensor`` objects directly.
    Such a comparison would raise ``TypeError`` unless the class defines ordering.
    Among equal weights, earlier insertions pop first.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # Created lazily by `create`; None until then.
        self.container = None
        # Tie-breaker counter for heap entries; reset by `create`.
        self._insert_count = 0

    def empty(self) -> bool:
        """Return True when the heap holds no entries. Requires `create` to have run."""
        assert self.container is not None
        return not self.container

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """(Re)build the heap from ``stateful_tensor_list``, discarding previous contents.

        Each weight uses ``cur_step=-1``, i.e. the first step in the tensor's
        step list that is greater than -1.
        """
        self.container = []
        self._insert_count = 0
        for stateful_tensor in stateful_tensor_list:
            # we want to pop the tensor which has the greatest next_step
            # so the weight is next_step multiplied by -1
            weight = -self.__get_next_compute_step(stateful_tensor, -1)
            self.container.append((weight, self._insert_count, stateful_tensor))
            self._insert_count += 1
        heapq.heapify(self.container)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert ``stateful_tensor``, weighted by its next compute step after ``cur_step``.

        The weight is fixed at push time. Pushing the same tensor again adds a
        separate entry rather than updating the existing one.
        """
        # we want to pop the tensor which has the greatest next_step
        # so the weight is next_step multiplied by -1
        weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
        heapq.heappush(self.container, (weight, self._insert_count, stateful_tensor))
        self._insert_count += 1

    def pop(self) -> Optional[StatefulTensor]:
        """Pop the evictable tensor with the furthest next compute step.

        Entries whose tensor fails `evict_check` are discarded. Returns None
        once the heap has been drained without finding a candidate.
        """
        while not self.empty():
            _, _, out_tensor = heapq.heappop(self.container)
            if evict_check(out_tensor):
                return out_tensor
        return None

    def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int) -> int:
        """Return the first step in the tensor's step list that is greater than ``cur_step``.

        If no such step exists (the tensor is not used in the future), return
        ``self.total_step``, which sorts after every real step. The list is
        scanned in stored order. NOTE(review): the result is only the *next*
        step if the list is sorted ascending; presumably it is, so confirm
        against the code that builds ``compute_step_dict``.
        """
        step_list = self.compute_step_dict[stateful_tensor]
        return next((step for step in step_list if step > cur_step), self.total_step)