diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 9c77590ff..3a3e5ddbf 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -10,6 +10,7 @@ def pipe_split(): def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): + # TODO(lyl): balanced policy V2, split module by node size(weight+bias+output) mod_graph = gm.graph total_param_amount = 0 for param in mod_graph.owning_module.parameters(): @@ -68,6 +69,9 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule): + # TODO(lyl): use partition IR to assign partition ID to each node. + # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph + # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node part_idx = 0 def split_callback(n: torch.fx.Node): diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index 4dfb292e2..d3e38c190 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,10 +1,12 @@ import torch from typing import Dict, Set from torch.fx.node import Node, map_arg +from torch.fx.graph import Graph def get_comm_size(prev_partition, next_partition): - """Given two partitions (parent and child), + """ + Given two partitions (parent and child), calculate the communication size between the two. """ # Keep tracking the communication size between parent and child @@ -25,3 +27,136 @@ def get_comm_size(prev_partition, next_partition): comm_size += n.meta['tensor_meta'].numel visited_nodes.add(n) return comm_size + + +def get_leaf(graph: Graph): + """ + Given a graph, return leaf nodes of this graph. 
+ + Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, + we will get a normal DAG. Leaf nodes in this context mean leaf nodes in that DAG. + """ + input_nodes: Dict[Node, None] = {} + for node in graph.nodes: + if node.op == 'output': + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + placeholder_nodes = [] + for node in input_nodes.keys(): + if node.op == 'placeholder': + placeholder_nodes.append(node) + for node in placeholder_nodes: + input_nodes.pop(node) + return list(input_nodes.keys()) + + +def is_leaf(graph: Graph, node: Node): + return node in get_leaf(graph) + + +def get_top(graph: Graph): + """ + Given a graph, return top nodes of this graph. + + Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, + we will get a normal DAG. Top nodes in this context mean nodes with BFS level 0 in that DAG. + """ + top_node_list = set() + for node in graph.nodes: + if node.op == 'output': + continue + is_top = False + + def _get_top(node): + nonlocal is_top + if node.op == 'placeholder': + is_top = True + + map_arg(node.args, lambda n: _get_top(n)) + map_arg(node.kwargs, lambda n: _get_top(n)) + if is_top: + top_node_list.add(node) + return list(top_node_list) + + +def is_top(graph: Graph, node: Node): + return node in get_top(graph) + + +def get_all_consumers(graph: Graph, node: Node): + """ + Given a graph and a node of this graph, return all consumers of the node. + + Returns: + List of nodes whose ``args`` or ``kwargs`` contain the given node. + """ + consumer_list = [] + for n in graph.nodes: + if node in n.all_input_nodes: + consumer_list.append(n) + return consumer_list + + +def assign_bfs_level_to_nodes(graph: Graph): + """ + Given a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes. 
+ + Example: + class MLP(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + self.linear5 = torch.nn.Linear(dim, dim) + + + def forward(self, x): + l1 = self.linear1(x) + l2 = self.linear2(x) + l3 = self.linear3(l1) + l4 = self.linear4(l2) + l5 = self.linear5(l3) + return l4, l5 + model = MLP(4) + gm = symbolic_trace(model) + print(gm.graph) + assign_bfs_level_to_nodes(gm.graph) + for node in gm.graph.nodes: + if hasattr(node, 'bfs_level'): + print(node.name, node.bfs_level) + + Output: + graph(): + %x : [#users=2] = placeholder[target=x] + %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {}) + %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {}) + %linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {}) + %linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {}) + %linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {}) + return (linear4, linear5) + linear1 0 + linear2 0 + linear3 1 + linear4 1 + linear5 2 + """ + current_level = 0 + nodes_to_process = [] + + top_nodes = get_top(graph) + for node in top_nodes: + node.bfs_level = current_level + nodes_to_process.extend(get_all_consumers(graph, node)) + + current_level += 1 + while nodes_to_process: + new_process_list = [] + for node in nodes_to_process: + if node.op == 'output': + continue + node.bfs_level = current_level + new_process_list.extend(get_all_consumers(graph, node)) + nodes_to_process = new_process_list + current_level += 1 diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py new file mode 100644 index 000000000..fb33e58a7 --- /dev/null +++ b/tests/test_fx/test_graph_manipulation.py @@ -0,0 +1,50 @@ +import colossalai +import torch +from 
colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes +from colossalai.fx import ColoTracer +from torch.fx import GraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + self.linear5 = torch.nn.Linear(dim, dim) + + def forward(self, x): + l1 = self.linear1(x) + l2 = self.linear2(x) + l3 = self.linear3(l1) + l4 = self.linear4(l2) + l5 = self.linear5(l3) + return l4, l5 + + +def test_graph_manipulation(): + model = MLP(4) + tracer = ColoTracer() + graph = tracer.trace(model) + nodes = list(graph.nodes) + x, l1, l2, l3, l4, l5, output = nodes + + leaf_nodes = set(get_leaf(graph)) + top_nodes = set(get_top(graph)) + compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None} + assign_bfs_level_to_nodes(graph) + + assert leaf_nodes == set([l4, l5]) + assert top_nodes == set([l1, l2]) + for node in graph.nodes: + if node.op in ('placeholder', 'output'): + assert not hasattr(node, 'bfs_level') + else: + assert node.bfs_level == compare_dict[node] + + +if __name__ == '__main__': + test_graph_manipulation()