from enum import Enum
from typing import Optional
import torch
from typing import Union

from colossalai.gemini.gemini_context import GeminiMemoryManager


def sizeof_tensor(tensor: torch.Tensor):
    """Return the byte size of the elements of ``tensor``.

    This is the element count times the byte width of one element of its dtype.
    For a view, it is the logical size of the view, not the size of the
    underlying storage.
    """
    n_elements = tensor.numel()
    bytes_per_element = tensor.element_size()
    return n_elements * bytes_per_element


class TensorState(Enum):
    """Lifecycle states of a ``StatefulTensor``.

    The enum class is passed to ``GeminiMemoryManager``'s constructor, and
    ``StatefulTensor`` indexes the manager's per-device ``state_mem`` counters by
    these members. Review that manager before reordering or renumbering them.
    """
    FREE = 0    # no payload: a FREE StatefulTensor holds None and 0 bytes
    HOLD = 1    # payload present; default state of a new StatefulTensor
    HOLD_AFTER_FWD = 2    # presumably: payload held after the forward pass -- TODO confirm
    HOLD_AFTER_BWD = 3    # presumably: payload held after the backward pass -- TODO confirm
    COMPUTE = 4    # presumably: payload currently used by a computation -- TODO confirm


class StatefulTensor(object):
    """A structure that stores a torch Tensor together with a labeled state.

    Inspired by the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

    https://arxiv.org/abs/2108.05818

    Every instance reports to the class-wide ``GST_MGR``:

    * the number of live instances (``register_new_instance`` / ``delete_instance``);
    * the payload bytes held per device type (``total_mem``);
    * the payload bytes held per device type and state (``state_mem``).

    A tensor in ``TensorState.FREE`` has no payload and contributes 0 bytes.
    """
    # Global Stateful Tensor Manager, shared by all instances
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        """Wrap ``maybe_tensor`` in ``state`` and register it with ``GST_MGR``.

        Args:
            maybe_tensor: the payload; must be ``None`` exactly when ``state`` is
                ``TensorState.FREE``.
            state: initial state. Defaults to ``TensorState.HOLD``.

        Raises:
            AssertionError: if ``maybe_tensor`` and ``state`` are inconsistent.
        """
        # Start as an empty FREE object. If validation below fails, __del__ then
        # has nothing to release and simply balances register_new_instance().
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0    # byte size of current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if state == TensorState.FREE:
            # when the state is free, payload should be None
            assert maybe_tensor is None, f"payload has to None if state is {state}"
        else:
            # otherwise, payload should not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            # Record the new payload, and only then adopt the requested state.
            self.__trans_state_update(TensorState.FREE, state)
            self._state = state

    def data_ptr(self):
        """Return the payload's data pointer, or 0 when there is no payload."""
        if self._payload is None:
            return 0    # if a tensor has no storage, 0 should be returned
        return self._payload.data_ptr()

    def set_null(self) -> None:
        """Drop the payload and move to ``TensorState.FREE``.

        This updates the manager's accounting. It is a no-op for a tensor that is
        already FREE.
        """
        # notice that free stateful tensor do not need to become null again
        if self.state != TensorState.FREE:
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        """Return True if the tensor is FREE (and therefore has no payload)."""
        if self.state == TensorState.FREE:
            # check sanity here
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Move to ``state``, keeping the manager's per-state accounting in sync.

        Moving to FREE releases the payload. A FREE tensor can only "move" to
        FREE again (a no-op).

        Raises:
            AssertionError: if a FREE tensor is asked to move to a non-FREE state.
        """
        if self.state == TensorState.FREE:
            # free stateful tensor can't change state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int, str]):
        """Move the payload to ``device``, updating the manager if the device type changes.

        Args:
            device: a ``torch.device``, a device string such as ``'cpu'`` or
                ``'cuda:1'``, or an integer CUDA device index.

        Raises:
            AssertionError: if the tensor is FREE.
        """
        assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

        if isinstance(device, torch.device):
            to_device = device
        elif isinstance(device, str):
            to_device = torch.device(device)
        else:
            # a bare integer is interpreted as a CUDA device index
            to_device = torch.device('cuda', device)

        from_device = self.device
        if from_device.type == to_device.type:
            # The manager only tracks memory per device *type*, so no bookkeeping is
            # needed. Still honour an explicit, different index (e.g. cuda:0 -> cuda:1)
            # instead of silently keeping the data where it is.
            if to_device.index is not None and to_device.index != from_device.index:
                self.payload.data = self.payload.data.to(to_device)
            return

        # Copy first, so a failed transfer (e.g. out of memory) leaves the
        # manager's accounting untouched; then update the manager's information.
        self.payload.data = self.payload.data.to(to_device)
        self.__trans_device_update(from_device.type, to_device.type)

    def payload_copy(self, tensor) -> None:
        """Copy the elements of ``tensor`` into the existing payload, in place.

        Both tensors are handled as flat sequences. ``tensor`` may be non-contiguous.
        """
        # view(-1) on the destination either aliases the payload's storage or raises,
        # so the write always lands in the payload. reshape(-1) on the source also
        # accepts non-contiguous tensors.
        self._payload.view(-1).copy_(tensor.reshape(-1))

    def payload_reset(self, tensor) -> None:
        """Replace the payload with ``tensor``, keeping the manager's accounting in sync.

        If a payload is present, the current state is kept. Otherwise (a FREE
        tensor), the state becomes ``TensorState.HOLD``.

        Raises:
            AssertionError: if ``tensor`` is ``None``; use ``set_null()`` instead.
        """
        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

        if self.payload is not None:
            # release old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for new payload
            self._state = TensorState.HOLD

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        """Take over ``rhs``'s payload.

        Afterwards this tensor is HOLD and ``rhs`` is FREE. The bytes are moved from
        ``rhs``'s state to HOLD in the manager; the total memory is unchanged.

        Raises:
            AssertionError: if ``rhs`` is FREE, or this tensor's payload size is not 0.
        """
        # relay the payload of rhs to current stateful tensor
        # can't support null relay right now
        assert not rhs.is_null()

        # now this function only support stateful tensor that has zero-length payload
        # because it doesn't require memory manager updating
        # you can extend this function by yourself
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        # rhs no longer owns the payload; its bytes were transferred above
        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        """The wrapped tensor, or ``None`` when FREE."""
        return self._payload

    @property
    def payload_size(self) -> int:
        """Byte size of the payload (0 when FREE)."""
        return self._payload_size

    @property
    def state(self) -> TensorState:
        """Current state of the tensor."""
        return self._state

    @property
    def device(self) -> torch.device:
        """Device of the payload; the payload must not be ``None``."""
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the payload; the payload must not be ``None``."""
        return self._payload.dtype

    @property
    def shape(self):
        """Shape of the payload; the payload must not be ``None``."""
        return self._payload.shape

    def to(self, device: torch.device):
        # Disallowed: moving must go through move_to() so the manager stays in sync.
        raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")

    def to_(self, device: torch.device):
        # Disallowed: moving must go through move_to() so the manager stays in sync.
        raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")

    def __release(self):
        """Reset to FREE with no payload.

        This does not touch the manager; callers update the accounting themselves.
        """
        # release current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update global manager when changing the state of a tensor.

        This uses the current payload's size and device type, so it must be called
        while the payload is set. Moving from FREE adds the bytes to ``total_mem``;
        moving to FREE subtracts them.
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to manager
            # we should add its memory
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon
            # we should sub its memory
            manager.total_mem[device_type] -= size

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update global manager when changing the device type of a tensor.

        The payload's bytes are moved from ``from_type`` to ``to_type``, both in
        ``total_mem`` and in ``state_mem`` for the current state.
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size

    def __del__(self):
        # Release the payload's accounting, then unregister the instance
        # (the counterpart of register_new_instance() in __init__).
        self.set_null()
        StatefulTensor.GST_MGR.delete_instance()