2022-07-07 08:22:00 +00:00
|
|
|
import torch
|
|
|
|
from typing import Dict, Set
|
|
|
|
from torch.fx.node import Node, map_arg
|
|
|
|
|
|
|
|
|
|
|
|
def get_comm_size(prev_partition, next_partition):
    """Given two partitions (parent and child),
    calculate the communication size between the two.

    Every node in ``next_partition.graph`` is inspected. Any node referenced in
    its ``args``/``kwargs`` whose name also appears among the nodes of
    ``prev_partition.graph`` is treated as data sent from parent to child.
    Its ``meta['tensor_meta'].numel`` is added to the total. Each such input
    node is counted at most once, even if several child nodes consume it.

    Args:
        prev_partition: The parent partition. Only its ``.graph.nodes`` are
            read (node names are used for matching).
        next_partition: The child partition. Only its ``.graph.nodes`` are
            read.

    Returns:
        The summed ``numel`` of all parent-produced inputs of the child.
        The unit is whatever ``tensor_meta.numel`` holds, presumably an
        element count rather than bytes (TODO confirm).

    Raises:
        KeyError: If a matched input node has no ``'tensor_meta'`` entry in
            its ``meta`` dict. Metadata propagation is presumably expected to
            have run beforehand.
    """
    # Keep tracking the communication size between parent and child
    comm_size = 0
    # Keep tracking all the counted nodes so shared inputs are counted once
    visited_nodes = set()
    # Matching is done by node *name*. A set gives O(1) membership tests
    # instead of scanning a list for every input of every child node.
    parent_node_names = {n.name for n in prev_partition.graph.nodes}
    # Go through all nodes in the child partition.
    # If a node has input nodes from the parent partition,
    # the output size of those input nodes is counted and added to comm_size.
    for node in next_partition.graph.nodes:
        # A dict (insertion-ordered) de-duplicates inputs seen in args/kwargs.
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size
|