ColossalAI/colossalai/fx/passes/utils.py

28 lines
1.1 KiB
Python
Raw Normal View History

import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg
def get_comm_size(prev_partition, next_partition):
    """Calculate the communication size between a parent and a child partition.

    Every node of ``next_partition`` (the child) is inspected. Any of its input
    nodes, whether referenced through ``args`` or ``kwargs``, whose name also
    appears in ``prev_partition`` (the parent) contributes its
    ``meta['tensor_meta'].numel`` to the total. Each such input node is
    counted only once, even if several child nodes consume it.

    Args:
        prev_partition: The parent partition. Only
            ``prev_partition.graph.nodes`` is read.
        next_partition: The child partition. Only
            ``next_partition.graph.nodes`` is read.

    Returns:
        The summed ``numel`` of all distinct child-graph input nodes whose
        names occur in the parent graph; ``0`` if there are none.

    Raises:
        KeyError: If a counted input node has no ``'tensor_meta'`` entry in its
            ``meta`` dict. NOTE(review): presumably populated by a
            shape/meta-propagation pass run beforehand; confirm with callers.
    """
    # Nodes are matched by name rather than identity, presumably because each
    # partition owns its own graph whose Node objects are distinct.
    # A set makes each membership test O(1) instead of a linear scan.
    parent_node_names = {n.name for n in prev_partition.graph.nodes}

    comm_size = 0
    # Parent-produced inputs already accounted for, so shared inputs are not
    # double counted.
    visited_nodes = set()
    for node in next_partition.graph.nodes:
        # all_input_nodes: the unique Node inputs referenced from both
        # node.args and node.kwargs, in first-seen order.
        for n in node.all_input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size