mirror of https://github.com/hpcaitech/ColossalAI
91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
from typing import List, Any
|
|
from torch.fx import GraphModule, Node
|
|
|
|
# Common nodes are type of nodes that could be seen as attributes and remain
|
|
# unchanged throughout the whole model, it will be used several times by
|
|
# different blocks of model, so that it is hard for us to linearize the graph
|
|
# when we encounter those kinds of nodes. We let users to annotate some of the
|
|
# input as common node, such as attention mask, and the followings are some of
|
|
# the ops that could actually be seen as common nodes. With our common node prop,
|
|
# we could find some of the "real" common nodes (e.g. the real attention mask
|
|
# used in BERT and GPT), the rule is simple, for node who's parents are all common
|
|
# nodes or it's op belongs to the following operations, we view this node as a
|
|
# newly born common node.
|
|
# List of target name that could be seen as common node
|
|
COPS = ["getattr", "getitem", "size"]
|
|
|
|
|
|
def _is_cop(target: Any) -> bool:
|
|
"""Check if an op could be seen as common node
|
|
|
|
Args:
|
|
target (Any): node target
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
if isinstance(target, str):
|
|
return target in COPS
|
|
else:
|
|
return target.__name__ in COPS
|
|
|
|
|
|
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
|
|
"""Linearizing the graph
|
|
|
|
Args:
|
|
gm (GraphModule): GraphModule derived by tracing
|
|
cnode (List[str], optional): common node List, should be the subset of input. Default to None.
|
|
|
|
Returns:
|
|
List[List[Node]]: List of list, each inside list of Node presents
|
|
the actual 'node' in linearized manner.
|
|
"""
|
|
|
|
def _is_sink() -> bool:
|
|
"""Check if we can free all dependencies
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
return not sum([v for _, v in deps.items()])
|
|
|
|
# make sure that item in cnode is valid
|
|
if cnode:
|
|
for name in cnode:
|
|
try:
|
|
assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
|
|
f"common node {name} is not an input of the model"
|
|
except StopIteration:
|
|
raise ValueError(f"common node name {name} not in graph")
|
|
|
|
else:
|
|
cnode = []
|
|
|
|
deps = {}
|
|
linearized_nodes = []
|
|
region = []
|
|
|
|
for n in gm.graph.nodes:
|
|
if n.op != "placeholder" and n.op != "output":
|
|
for n_par in n._input_nodes:
|
|
if n_par.op != "placeholder" and n_par.name not in cnode:
|
|
deps[n_par] -= 1
|
|
region.append(n)
|
|
|
|
# if the node could free all dependencies in graph
|
|
# we could begin a new node
|
|
if _is_sink():
|
|
linearized_nodes.append(region)
|
|
region = []
|
|
|
|
# propagate common node attr if possible
|
|
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
|
|
cnode.append(n.name)
|
|
else:
|
|
deps[n] = len([user for user in n.users if user.op != "output"])
|
|
|
|
return linearized_nodes
|