2022-09-02 02:24:41 +00:00
|
|
|
from typing import List
|
|
|
|
from torch.fx import GraphModule, Node
|
2022-08-26 02:34:21 +00:00
|
|
|
|
|
|
|
|
2022-09-05 10:35:05 +00:00
|
|
|
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
    """Split the compute nodes of a traced graph into consecutive regions.

    Nodes are visited in graph order and appended to the current region.
    The region is closed right after a node is appended if, at that point,
    every previously tracked node has had all of its non-output users
    visited (no tracked dependency is outstanding).

    "Common" nodes are excluded from dependency tracking. The names given in
    ``cnode`` seed the set of common nodes, and every visited node whose
    inputs are all common (including a node without inputs) becomes common
    as well.

    Args:
        gm (GraphModule): GraphModule derived by tracing.
        cnode (List[str], optional): common node list, should be a subset of
            the graph inputs (placeholders). The list is only read; it is
            never modified by this function. Defaults to None.

    Returns:
        List[List[Node]]: List of regions; each inner list of Node represents
        one actual 'node' of the linearized graph. Placeholder and output
        nodes are never included.

    Raises:
        ValueError: if a name in ``cnode`` does not exist in the graph.
        AssertionError: if a name in ``cnode`` refers to a node that is not
            a placeholder.
    """

    def _is_sink() -> bool:
        """Check if we can free all dependencies

        Returns:
            bool: True when no tracked node has an outstanding user left.
        """
        return not sum(deps.values())

    # Work on a private set: the caller's list must not grow as common-ness
    # is propagated below, and membership tests become O(1).
    common = set(cnode) if cnode else set()

    # make sure that item in cnode is valid
    if cnode:
        nodes_by_name = {node.name: node for node in gm.graph.nodes}
        for name in cnode:
            node = nodes_by_name.get(name)
            if node is None:
                raise ValueError(f"common node name {name} not in graph")
            if node.op != "placeholder":
                # Raised explicitly instead of via `assert` so the check is
                # not stripped when Python runs with -O.
                raise AssertionError(f"common node {name} is not an input of the model")

    # deps maps a tracked node to the number of its users not yet visited.
    deps = {}
    linearized_nodes = []
    region = []

    for n in gm.graph.nodes:
        if n.op in ("placeholder", "output"):
            continue

        inputs = n.all_input_nodes

        # visiting n consumes one pending use of each tracked parent
        for n_par in inputs:
            if n_par.op != "placeholder" and n_par.name not in common:
                deps[n_par] -= 1
        region.append(n)

        # if the node could free all dependencies in graph
        # we could begin a new node
        # NOTE: this check runs before n's own user count is registered, so
        # n itself never keeps the region it closes open.
        if _is_sink():
            linearized_nodes.append(region)
            region = []

        # propagate common node attr if possible
        if all(n_par.name in common for n_par in inputs):
            common.add(n.name)
        else:
            deps[n] = sum(user.op != "output" for user in n.users)

    return linearized_nodes
|