from colossalai.tensor import ColoTensor
from colossalai.context.singleton_meta import SingletonMeta


class GraphGlobalEnv(metaclass=SingletonMeta):
    """Global bookkeeping for computing-graph construction.

    Built with ``SingletonMeta`` as its metaclass. The rest of this module
    relies on every ``GraphGlobalEnv()`` call yielding the same shared state
    (``GraphContext`` writes ``graph_building`` and ``GraphNode`` reads it).
    """

    def __init__(self) -> None:
        # Flag toggled by GraphContext.__enter__/__exit__.
        self.graph_building = False
        # GraphOpNode instances register themselves here.
        self.graph_node_list = []
        # Most recently issued node id; -1 means no id issued yet.
        self.node_id = -1

    def get_node_id(self):
        """Advance the id counter and return the freshly issued id."""
        issued = self.node_id + 1
        self.node_id = issued
        return issued

    def add_graph_node(self, node):
        """Append ``node`` to the global node list."""
        registry = self.graph_node_list
        registry.append(node)


class GraphContext:
    """
    Build the computing graph under the context.

    While the context is active, ``GraphGlobalEnv().graph_building`` is True,
    so graph nodes record their edges. On exit, the op nodes collected in
    ``GraphGlobalEnv().graph_node_list`` are exposed as ``graph_nodes``.

    >>> with GraphContext() as ctx:
    ...     output = model(colo_input_tensor)
    >>> nodes = ctx.graph_nodes
    """
    # Class-level default kept for backward compatibility. Each instance gets
    # its own list in __init__ so instances never share a mutable list.
    graph_nodes = []

    def __init__(self) -> None:
        self.graph_nodes = []

    def __enter__(self):
        env = GraphGlobalEnv()
        env.graph_building = True
        env.graph_node_list = []
        # Restart id numbering for this graph. Ids may have been consumed by
        # nodes created outside any context.
        env.node_id = -1
        # Returning self lets callers write ``with GraphContext() as ctx``.
        return self

    def __exit__(self, *exc_info):
        env = GraphGlobalEnv()
        env.graph_building = False
        env.node_id = -1
        self.graph_nodes = env.graph_node_list
        # Implicitly returns None, so exceptions raised in the block propagate.


class GraphNode:
    """A vertex of the computing graph with links to its neighbours.

    An id is drawn from ``GraphGlobalEnv`` on construction, whether or not a
    graph is being built. Edges are only recorded while
    ``GraphGlobalEnv().graph_building`` is True; otherwise the ``add_*``
    methods do nothing.
    """

    def __init__(self) -> None:
        self.prev_nodes = []  # upstream neighbours
        self.post_nodes = []  # downstream neighbours
        self.id = GraphGlobalEnv().get_node_id()

    def add_prev_node(self, node):
        """Record ``node`` as upstream of this one (only while building)."""
        if not GraphGlobalEnv().graph_building:
            return
        self.prev_nodes.append(node)

    def add_post_node(self, node):
        """Record ``node`` as downstream of this one (only while building)."""
        if not GraphGlobalEnv().graph_building:
            return
        self.post_nodes.append(node)

    def post_node_empty(self) -> bool:
        """Return True when no downstream node has been recorded."""
        return not self.post_nodes


class GraphOpNode(GraphNode):
    """Graph node representing one operator invocation.

    Every instance appends itself to ``GraphGlobalEnv().graph_node_list`` on
    construction. Links to tensors are made through ``add_prev_tensor`` and
    ``add_post_tensor``, which only act while graph building is active.
    """

    def __init__(self, op_type, param_list) -> None:
        super().__init__()
        # NOTE(review): presumably the recorded op and its parameters; confirm with callers.
        self._op_type = op_type
        self._param_list = param_list
        # Registered unconditionally, even when graph building is off.
        GraphGlobalEnv().add_graph_node(self)

    def add_prev_tensor(self, colo_tensor: ColoTensor):
        r"""
        Link the current graph op node to previous graph op.
        Op1 <- Activation (colo_tensor) Op2
        Op1 <- Op2

        Every op recorded in the tensor node's ``prev_nodes`` (its producers)
        becomes a prev node of this op, and this op becomes a post node of
        each of them. Does nothing unless graph building is active.

        Args:
            colo_tensor: tensor consumed by this op.

        Raises:
            AssertionError: if ``colo_tensor`` is not a ``ColoTensor`` while
                graph building is active.
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
            if colo_tensor._graph_node is None:
                # Attach an empty node so the tensor carries one from now on.
                # It has no producers yet, so the loop below is a no-op.
                colo_tensor._graph_node = GraphNode()
            for op_node in colo_tensor._graph_node.prev_nodes:
                self.add_prev_node(op_node)
                op_node.add_post_node(self)

    def add_post_tensor(self, colo_tensor: ColoTensor):
        """
        Op <- Activation (colo_tensor)

        Record this op as a producer of ``colo_tensor``: it is appended to the
        tensor node's ``prev_nodes``, creating that node if needed. Does
        nothing unless graph building is active.

        Args:
            colo_tensor: tensor produced by this op.

        Raises:
            AssertionError: if ``colo_tensor`` is not a ``ColoTensor`` while
                graph building is active.
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()

            colo_tensor._graph_node.add_prev_node(self)

    def print(self):
        """Print the op type, this node's id, and the ids of its post and prev nodes."""
        post_ids = [node.id for node in self.post_nodes]
        prev_ids = [node.id for node in self.prev_nodes]
        # Inside this method ``print`` still resolves to the builtin; class
        # attributes are not part of a function's name lookup.
        print(f'GraphOpNode {self._op_type} {self.id}, post nodes {post_ids}, prev nodes {prev_ids}')