ColossalAI/colossalai/fx/passes/algorithms/linearize.py

62 lines
1.9 KiB
Python
Raw Normal View History

from typing import List
from torch.fx import GraphModule, Node
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
"""Linearizing the graph
Args:
gm (GraphModule): GraphModule derived by tracing
cnode (List[str], optional): common node List, should be the subset of input. Default to None.
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
"""
def _is_sink() -> bool:
"""Check if we can free all dependencies
Returns:
bool
"""
return not sum([v for _, v in deps.items()])
# make sure that item in cnode is valid
if cnode:
for name in cnode:
try:
assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
f"common node {name} is not an input of the model"
except StopIteration:
raise ValueError(f"common node name {name} not in graph")
else:
cnode = []
deps = {}
linearized_nodes = []
region = []
for n in gm.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n._input_nodes:
if n_par.op != "placeholder" and n_par.name not in cnode:
deps[n_par] -= 1
region.append(n)
# if the node could free all dependencies in graph
# we could begin a new node
if _is_sink():
linearized_nodes.append(region)
region = []
# propagate common node attr if possible
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]):
cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
return linearized_nodes