import torch
from typing import Dict, List, Set

from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph


def get_comm_size(prev_partition, next_partition):
    """
    Given two partitions (parent and child),
    calculate the communication size between the two.

    Every node of the parent partition that is consumed by at least one node
    of the child partition contributes ``meta['tensor_meta'].numel`` exactly
    once, no matter how many child nodes read it. Nodes are matched between
    the two partitions by ``name``.

    Args:
        prev_partition: parent partition; only ``.graph.nodes`` is read.
        next_partition: child partition; only ``.graph.nodes`` is read.

    Returns:
        The sum of ``numel`` over the distinct parent nodes consumed by the
        child partition.

    Note:
        Assumes ``meta['tensor_meta']`` exposes ``numel`` as a plain numeric
        attribute (not a method) -- TODO confirm against the metadata type
        this project stores in ``node.meta``.
    """
    # A set gives O(1) membership tests; the names are only ever looked up.
    parent_node_names = {n.name for n in prev_partition.graph.nodes}
    # Parent-side nodes whose size has already been added.
    counted_nodes = set()
    comm_size = 0
    for node in next_partition.graph.nodes:
        # ``all_input_nodes`` holds every Node found in ``args``/``kwargs``.
        for producer in node.all_input_nodes:
            if producer.name in parent_node_names and producer not in counted_nodes:
                comm_size += producer.meta['tensor_meta'].numel
                counted_nodes.add(producer)
    return comm_size


def get_leaf(graph: Graph) -> List[Node]:
    """
    Given a graph, return leaf nodes of this graph.

    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output``
    nodes from fx graph, we will get a normal DAG. Leaf nodes in this context
    means leaf nodes in that DAG.

    Concretely, these are the non-``placeholder`` nodes that feed an
    ``output`` node directly.

    Returns:
        The leaf nodes, without duplicates, in the order they are first seen
        in the ``output`` node's arguments.
    """
    # A dict is used as an insertion-ordered set.
    leaf_nodes: Dict[Node, None] = {}
    for node in graph.nodes:
        if node.op != 'output':
            continue
        for producer in node.all_input_nodes:
            # A placeholder returned unchanged is not part of the inner DAG.
            if producer.op != 'placeholder':
                leaf_nodes.setdefault(producer)
    return list(leaf_nodes)


def is_leaf(graph: Graph, node: Node) -> bool:
    """Return whether ``node`` is one of the leaf nodes of ``graph``."""
    return node in get_leaf(graph)


def get_top(graph: Graph) -> List[Node]:
    """
    Given a graph, return top nodes of this graph.

    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output``
    nodes from fx graph, we will get a normal DAG. Top nodes in this context
    means nodes with BFS level 0 in that DAG.

    Concretely, a node is reported as top when it is not an ``output`` node
    and at least one of its inputs is a ``placeholder``. A node that also has
    non-placeholder inputs is still reported here.

    Returns:
        The top nodes in graph order.
    """
    return [
        node
        for node in graph.nodes
        if node.op != 'output'
        and any(producer.op == 'placeholder' for producer in node.all_input_nodes)
    ]


def is_top(graph: Graph, node: Node) -> bool:
    """Return whether ``node`` is one of the top nodes of ``graph``."""
    return node in get_top(graph)


def get_all_consumers(graph: Graph, node: Node) -> List[Node]:
    """
    Given a graph and a node of this graph, return all consumers of the node.

    Returns:
        List of ``Nodes`` in which ``node`` appears in ``args`` or ``kwargs``,
        in graph order.
    """
    return [user for user in graph.nodes if node in user.all_input_nodes]


def assign_bfs_level_to_nodes(graph: Graph) -> None:
    """
    Given a graph, assign bfs level to each node of this graph excluding
    ``placeholder`` and ``output`` nodes.

    The level is stored on each node as the attribute ``bfs_level``. Top nodes
    (see ``get_top``) start at level 0 and every consumer is placed one level
    below its producer. A node reachable at several depths is written once per
    depth, so it ends up with the deepest level at which it is reached.

    Example:
        class MLP(torch.nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.linear1 = torch.nn.Linear(dim, dim)
                self.linear2 = torch.nn.Linear(dim, dim)
                self.linear3 = torch.nn.Linear(dim, dim)
                self.linear4 = torch.nn.Linear(dim, dim)
                self.linear5 = torch.nn.Linear(dim, dim)

            def forward(self, x):
                l1 = self.linear1(x)
                l2 = self.linear2(x)
                l3 = self.linear3(l1)
                l4 = self.linear4(l2)
                l5 = self.linear5(l3)
                return l4, l5

        model = MLP(4)
        gm = symbolic_trace(model)
        print(gm.graph)
        assign_bfs_level_to_nodes(gm.graph)
        for node in gm.graph.nodes:
            if hasattr(node, 'bfs_level'):
                print(node.name, node.bfs_level)

    Output:
        graph():
            %x : [#users=2] = placeholder[target=x]
            %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
            %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})
            %linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {})
            %linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {})
            %linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {})
            return (linear4, linear5)
        linear1 0
        linear2 0
        linear3 1
        linear4 1
        linear5 2
    """
    current_level = 0
    # The frontier is kept as an insertion-ordered set: a node consumed by
    # several nodes of the previous level only needs to be expanded once per
    # level, which leaves the resulting levels unchanged.
    frontier: Dict[Node, None] = {}
    for node in get_top(graph):
        node.bfs_level = current_level
        frontier.update(dict.fromkeys(get_all_consumers(graph, node)))
    current_level += 1

    while frontier:
        next_frontier: Dict[Node, None] = {}
        for node in frontier:
            # ``output`` nodes get no level and have no consumers to expand.
            if node.op == 'output':
                continue
            node.bfs_level = current_level
            next_frontier.update(dict.fromkeys(get_all_consumers(graph, node)))
        frontier = next_frontier
        current_level += 1