import torch
from typing import Dict

from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph


def get_comm_size(prev_partition, next_partition):
    """
    Given two partitions (parent and child),
    calculate the communication size between the two.
    """
    # Running total of the communication size between parent and child
    comm_size = 0
    # Nodes that have already been counted
    visited_nodes = set()
    # Go through all nodes in the child partition.
    # If a node has input nodes from the parent partition,
    # the output size of those input nodes is counted
    # and added to comm_size.
    parent_node_names = [n.name for n in prev_partition.graph.nodes]
    for node in next_partition.graph.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size


def get_leaf(graph: Graph):
    """
    Given a graph, return the leaf nodes of this graph.

    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from an fx graph,
    we get a normal DAG. Leaf nodes in this context mean the leaf nodes of that DAG.
    """
    input_nodes: Dict[Node, None] = {}
    for node in graph.nodes:
        if node.op == 'output':
            map_arg(node.args, lambda n: input_nodes.setdefault(n))
            map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
    placeholder_nodes = []
    for node in input_nodes.keys():
        if node.op == 'placeholder':
            placeholder_nodes.append(node)
    for node in placeholder_nodes:
        input_nodes.pop(node)
    return list(input_nodes.keys())


def is_leaf(graph: Graph, node: Node):
    return node in get_leaf(graph)
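

# Example (a minimal, illustrative sketch): ``Net`` below is a hypothetical two-layer
# module made up purely for demonstration; any traceable ``nn.Module`` behaves the same.
# Its only leaf is the last ``call_module`` node, since nothing but ``output`` consumes it.
#
#     from torch.fx import symbolic_trace
#
#     class Net(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.fc1 = torch.nn.Linear(4, 4)
#             self.fc2 = torch.nn.Linear(4, 4)
#
#         def forward(self, x):
#             return self.fc2(self.fc1(x))
#
#     gm = symbolic_trace(Net())
#     print([n.name for n in get_leaf(gm.graph)])         # ['fc2']
#     print(is_leaf(gm.graph, list(gm.graph.nodes)[-2]))   # True (the fc2 node)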


def get_top(graph: Graph):
    """
    Given a graph, return the top nodes of this graph.

    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from an fx graph,
    we get a normal DAG. Top nodes in this context mean the nodes at BFS level 0 in that DAG.
    """
    top_node_list = set()
    for node in graph.nodes:
        if node.op == 'output':
            continue
        is_top = False

        def _get_top(node):
            nonlocal is_top
            if node.op == 'placeholder':
                is_top = True

        map_arg(node.args, lambda n: _get_top(n))
        map_arg(node.kwargs, lambda n: _get_top(n))
        if is_top:
            top_node_list.add(node)
    return list(top_node_list)


def is_top(graph: Graph, node: Node):
    return node in get_top(graph)
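

# Example (a minimal sketch, reusing the hypothetical ``Net`` from the sketch after
# ``is_leaf``): the only top node is ``fc1``, because it is the sole node fed directly
# by a ``placeholder``.
#
#     gm = symbolic_trace(Net())
#     print([n.name for n in get_top(gm.graph)])        # ['fc1']
#     print(is_top(gm.graph, list(gm.graph.nodes)[1]))   # True (the fc1 node)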


def get_all_consumers(graph: Graph, node: Node):
    """
    Given a graph and a node of this graph, return all consumers of the node.

    Returns:
        List of ``Node`` objects whose ``args`` or ``kwargs`` contain the given node.
    """
    consumer_list = []
    for n in graph.nodes:
        if node in n.all_input_nodes:
            consumer_list.append(n)
    return consumer_list
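

# Example (a minimal sketch, again using the hypothetical ``Net`` from above): the
# ``fc1`` node is consumed only by ``fc2``.
#
#     gm = symbolic_trace(Net())
#     fc1_node = list(gm.graph.nodes)[1]
#     print([n.name for n in get_all_consumers(gm.graph, fc1_node)])   # ['fc2']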


def assign_bfs_level_to_nodes(graph: Graph):
    """
    Given a graph, assign a BFS level to each node of this graph, excluding ``placeholder`` and ``output`` nodes.

    Example:
        class MLP(torch.nn.Module):

            def __init__(self, dim: int):
                super().__init__()
                self.linear1 = torch.nn.Linear(dim, dim)
                self.linear2 = torch.nn.Linear(dim, dim)
                self.linear3 = torch.nn.Linear(dim, dim)
                self.linear4 = torch.nn.Linear(dim, dim)
                self.linear5 = torch.nn.Linear(dim, dim)

            def forward(self, x):
                l1 = self.linear1(x)
                l2 = self.linear2(x)
                l3 = self.linear3(l1)
                l4 = self.linear4(l2)
                l5 = self.linear5(l3)
                return l4, l5

        model = MLP(4)
        gm = symbolic_trace(model)
        print(gm.graph)
        assign_bfs_level_to_nodes(gm.graph)
        for node in gm.graph.nodes:
            if hasattr(node, 'bfs_level'):
                print(node.name, node.bfs_level)

    Output:
        graph():
            %x : [#users=2] = placeholder[target=x]
            %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
            %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})
            %linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {})
            %linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {})
            %linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {})
            return (linear4, linear5)
        linear1 0
        linear2 0
        linear3 1
        linear4 1
        linear5 2
    """
    current_level = 0
    nodes_to_process = []

    top_nodes = get_top(graph)
    for node in top_nodes:
        node.bfs_level = current_level
        nodes_to_process.extend(get_all_consumers(graph, node))

    current_level += 1
    while nodes_to_process:
        new_process_list = []
        for node in nodes_to_process:
            if node.op == 'output':
                continue
            node.bfs_level = current_level
            new_process_list.extend(get_all_consumers(graph, node))
        nodes_to_process = new_process_list
        current_level += 1
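

# Follow-up sketch (assuming the ``MLP`` class and traced ``gm`` from the docstring
# example above): once levels are assigned, nodes can be bucketed by level, e.g. when
# deciding where to cut a pipeline.
#
#     from collections import defaultdict
#
#     levels = defaultdict(list)
#     for node in gm.graph.nodes:
#         if hasattr(node, 'bfs_level'):
#             levels[node.bfs_level].append(node.name)
#     # {0: ['linear1', 'linear2'], 1: ['linear3', 'linear4'], 2: ['linear5']}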


def get_node_module(node) -> torch.nn.Module:
    """
    Find the module associated with the given node.

    Args:
        node (torch.fx.Node): a torch.fx.Node object in the fx computation graph

    Returns:
        torch.nn.Module: the module associated with the given node
    """
    assert node.graph.owning_module is not None, \
        'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
    assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
    module = node.graph.owning_module.get_submodule(node.target)
    return module
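

# Example (a minimal sketch, once more with the hypothetical ``Net`` from above): every
# ``call_module`` node maps back to the submodule it calls on the owning ``GraphModule``.
#
#     gm = symbolic_trace(Net())
#     for node in gm.graph.nodes:
#         if node.op == 'call_module':
#             print(node.name, get_node_module(node))
#     # fc1 Linear(in_features=4, out_features=4, bias=True)
#     # fc2 Linear(in_features=4, out_features=4, bias=True)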