2022-09-02 02:24:41 +00:00
|
|
|
from typing import List
|
|
|
|
from torch.fx import GraphModule, Node
|
2022-08-26 02:34:21 +00:00
|
|
|
|
|
|
|
|
2022-09-02 02:24:41 +00:00
|
|
|
def linearize(gm: GraphModule) -> List[List[Node]]:
    """Linearize the graph into consecutive regions of compute nodes.

    Nodes are visited in ``gm.graph.nodes`` order. A region is closed as soon
    as every node visited so far has had all of its users visited as well,
    i.e. no value produced earlier is still waiting for a later consumer.
    Placeholder (input) and output nodes are never placed in a region, but
    placeholders still count as producers: a region stays open while an
    input still has unvisited users.

    Args:
        gm (GraphModule): GraphModule derived by tracing.

    Returns:
        List[List[Node]]: List of regions; each inner list holds the compute
        nodes of one region, in graph order.
    """
    # Number of (producer, consumer) edges whose consumer has not been
    # visited yet. Zero means nothing produced so far is still pending,
    # so the current region can be closed.
    pending_uses = 0
    linearized_nodes = []
    region = []

    for n in gm.graph.nodes:
        # Visiting n consumes one pending use of each of its distinct inputs.
        pending_uses -= len(n.all_input_nodes)

        # Inputs and the output are not part of any region. Filtering by op
        # (rather than slicing off the first/last region afterwards) keeps
        # compute nodes that would otherwise share a region with the output
        # node, and keeps extra placeholders out of compute regions.
        if n.op not in ("placeholder", "output"):
            region.append(n)

        # if the node frees all dependencies seen so far, close the region
        if pending_uses == 0 and region:
            linearized_nodes.append(region)
            region = []

        # n's value stays live until each of its users has been visited.
        pending_uses += len(n.users)

    return linearized_nodes
|