mirror of https://github.com/hpcaitech/ColossalAI
[fx] methods to get fx graph property. (#1246)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* manipulation
* [fx] add graph manipulation methods.
* [fx] methods to get fx graph property.
* add unit test
* add docstring to explain top node and leaf node in this context
pull/1271/head
parent 30b4fc0eb0
commit 97d713855a
@@ -10,6 +10,7 @@ def pipe_split():


def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    # TODO(lyl): balanced policy V2, split module by node size (weight + bias + output)
    mod_graph = gm.graph
    total_param_amount = 0
    for param in mod_graph.owning_module.parameters():
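The TODO above points at a size-based balancing policy, and the hunk cuts off right after the parameter loop begins. A minimal sketch of the per-stage budget this loop presumably feeds; the numel() accumulation and the integer division are assumptions, not shown in this hunk:

import torch

def param_budget(gm: torch.fx.GraphModule, pp_size: int) -> int:
    # Sum parameter element counts over the module that owns the graph.
    total_param_amount = 0
    for param in gm.graph.owning_module.parameters():
        total_param_amount += param.numel()
    # Assumed balancing rule: each pipeline stage targets an equal share.
    return total_param_amount // pp_size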
@@ -68,6 +69,9 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):


def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
    # TODO(lyl): use partition IR to assign a partition ID to each node.
    # Currently: analyze graph -> annotate graph by inserting split nodes -> use the split-module pass to split the graph
    # In future: graph to partitions -> analyze partition IR -> recombine partitions for best performance -> assign a partition ID to each node
    part_idx = 0

    def split_callback(n: torch.fx.Node):
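The comments describe the current flow: annotate the graph with split markers, then let torch.fx's split-module pass cut it. A minimal sketch of that flow, assuming the ``pipe_split`` sentinel from the first hunk and ``torch.fx.passes.split_module.split_module``; the real ``split_callback`` body is not shown in this hunk, so the criterion below is an assumption:

import torch
from torch.fx.passes.split_module import split_module


def pipe_split():
    # Stand-in for the no-op split marker defined earlier in this file.
    pass


def split_annotated(annotated_gm: torch.fx.GraphModule, root: torch.nn.Module):
    part_idx = 0

    def split_callback(n: torch.fx.Node):
        # Assumed criterion: bump the partition index at every pipe_split
        # marker inserted during annotation; other nodes keep the current index.
        nonlocal part_idx
        if (n.op, n.target) == ('call_function', pipe_split):
            part_idx += 1
        return part_idx

    return split_module(annotated_gm, root, split_callback)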
@@ -1,10 +1,12 @@
import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph


def get_comm_size(prev_partition, next_partition):
    """
    Given two partitions (parent and child),
    calculate the communication size between the two.
    """
    # Keep tracking the communication size between parent and child
@@ -25,3 +27,136 @@ def get_comm_size(prev_partition, next_partition):
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size

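Note that ``comm_size`` accumulates ``tensor_meta.numel``, i.e. element counts rather than bytes; a toy check of that arithmetic:

import torch

# A (32, 128) activation crossing the partition boundary contributes
# 32 * 128 = 4096 elements to comm_size (multiply by dtype size for bytes).
t = torch.empty(32, 128)
assert t.numel() == 4096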

def get_leaf(graph: Graph):
    """
    Given a graph, return the leaf nodes of this graph.

    Note: If we remove the ``root`` node, ``placeholder`` nodes and ``output`` nodes from an fx graph,
    we get a normal DAG. Leaf nodes in this context are the leaf nodes of that DAG.
    """
    input_nodes: Dict[Node, None] = {}
    for node in graph.nodes:
        if node.op == 'output':
            map_arg(node.args, lambda n: input_nodes.setdefault(n))
            map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
    placeholder_nodes = []
    for node in input_nodes.keys():
        if node.op == 'placeholder':
            placeholder_nodes.append(node)
    for node in placeholder_nodes:
        input_nodes.pop(node)
    return list(input_nodes.keys())


def is_leaf(graph: Graph, node: Node):
    return node in get_leaf(graph)


def get_top(graph: Graph):
    """
    Given a graph, return the top nodes of this graph.

    Note: If we remove the ``root`` node, ``placeholder`` nodes and ``output`` nodes from an fx graph,
    we get a normal DAG. Top nodes in this context are the nodes at BFS level 0 of that DAG.
    """
    top_node_list = set()
    for node in graph.nodes:
        if node.op == 'output':
            continue
        is_top = False

        def _get_top(node):
            nonlocal is_top
            if node.op == 'placeholder':
                is_top = True

        map_arg(node.args, lambda n: _get_top(n))
        map_arg(node.kwargs, lambda n: _get_top(n))
        if is_top:
            top_node_list.add(node)
    return list(top_node_list)


def is_top(graph: Graph, node: Node):
    return node in get_top(graph)

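A minimal usage sketch for the helpers above, assuming they are importable from ``colossalai.fx.passes.utils`` as in the test below, and tracing with plain ``torch.fx`` for brevity:

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.utils import get_top, get_leaf


class TwoLayer(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))


gm = symbolic_trace(TwoLayer())
# fc1 is the only node fed directly by the placeholder, so it is the top;
# fc2 is the only node feeding the output node, so it is the leaf.
assert [n.name for n in get_top(gm.graph)] == ['fc1']
assert [n.name for n in get_leaf(gm.graph)] == ['fc2']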

def get_all_consumers(graph: Graph, node: Node):
    """
    Given a graph and a node of this graph, return all consumers of the node.

    Returns:
        A list of ``Node``s whose ``args`` or ``kwargs`` contain the given node.
    """
    consumer_list = []
    for n in graph.nodes:
        if node in n.all_input_nodes:
            consumer_list.append(n)
    return consumer_list

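Continuing the TwoLayer sketch above (and assuming ``get_all_consumers`` ships in the same ``colossalai.fx.passes.utils`` module), the helper walks the whole graph and collects every node whose inputs include the query node:

from colossalai.fx.passes.utils import get_all_consumers

# The placeholder `x` is consumed only by fc1 in the TwoLayer graph.
x_node = list(gm.graph.nodes)[0]
assert [n.name for n in get_all_consumers(gm.graph, x_node)] == ['fc1']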

def assign_bfs_level_to_nodes(graph: Graph):
    """
    Given a graph, assign a BFS level to each node of this graph, excluding ``placeholder`` and ``output`` nodes.

    Example:
        class MLP(torch.nn.Module):

            def __init__(self, dim: int):
                super().__init__()
                self.linear1 = torch.nn.Linear(dim, dim)
                self.linear2 = torch.nn.Linear(dim, dim)
                self.linear3 = torch.nn.Linear(dim, dim)
                self.linear4 = torch.nn.Linear(dim, dim)
                self.linear5 = torch.nn.Linear(dim, dim)

            def forward(self, x):
                l1 = self.linear1(x)
                l2 = self.linear2(x)
                l3 = self.linear3(l1)
                l4 = self.linear4(l2)
                l5 = self.linear5(l3)
                return l4, l5

        model = MLP(4)
        gm = symbolic_trace(model)
        print(gm.graph)
        assign_bfs_level_to_nodes(gm.graph)
        for node in gm.graph.nodes:
            if hasattr(node, 'bfs_level'):
                print(node.name, node.bfs_level)

    Output:
        graph():
            %x : [#users=2] = placeholder[target=x]
            %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
            %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})
            %linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {})
            %linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {})
            %linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {})
            return (linear4, linear5)
        linear1 0
        linear2 0
        linear3 1
        linear4 1
        linear5 2
    """
    current_level = 0
    nodes_to_process = []

    top_nodes = get_top(graph)
    for node in top_nodes:
        node.bfs_level = current_level
        nodes_to_process.extend(get_all_consumers(graph, node))

    current_level += 1
    while nodes_to_process:
        new_process_list = []
        for node in nodes_to_process:
            if node.op == 'output':
                continue
            node.bfs_level = current_level
            new_process_list.extend(get_all_consumers(graph, node))
        nodes_to_process = new_process_list
        current_level += 1
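Since ``assign_bfs_level_to_nodes`` stores the level as a plain attribute on each ``Node``, downstream passes can bucket nodes by level. A small illustrative consumer, not part of this commit:

from collections import defaultdict

def group_by_bfs_level(graph):
    # With this assignment a consumer always lands on a higher level than
    # its producers, so each bucket is a candidate set of ops that could
    # run concurrently.
    levels = defaultdict(list)
    for node in graph.nodes:
        if hasattr(node, 'bfs_level'):
            levels[node.bfs_level].append(node.name)
    return dict(levels)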
@@ -0,0 +1,50 @@
import colossalai
import torch
from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata


class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)
        self.linear5 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        l1 = self.linear1(x)
        l2 = self.linear2(x)
        l3 = self.linear3(l1)
        l4 = self.linear4(l2)
        l5 = self.linear5(l3)
        return l4, l5


def test_graph_manipulation():
    model = MLP(4)
    tracer = ColoTracer()
    graph = tracer.trace(model)
    nodes = list(graph.nodes)
    x, l1, l2, l3, l4, l5, output = nodes

    leaf_nodes = set(get_leaf(graph))
    top_nodes = set(get_top(graph))
    compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None}
    assign_bfs_level_to_nodes(graph)

    assert leaf_nodes == set([l4, l5])
    assert top_nodes == set([l1, l2])
    for node in graph.nodes:
        if node.op in ('placeholder', 'output'):
            assert not hasattr(node, 'bfs_level')
        else:
            assert node.bfs_level == compare_dict[node]


if __name__ == '__main__':
    test_graph_manipulation()