ColossalAI/tests/test_fx/test_graph_manipulation.py

import colossalai
import torch
from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
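

# A small MLP whose layers form two chains off the input:
#   x -> linear1 -> linear3 -> linear5
#   x -> linear2 -> linear4
# The traced graph is used below to exercise the graph-manipulation helpers.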
class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)
        self.linear5 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        l1 = self.linear1(x)
        l2 = self.linear2(x)
        l3 = self.linear3(l1)
        l4 = self.linear4(l2)
        l5 = self.linear5(l3)
        return l4, l5
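

# Trace the MLP with ColoTracer, then check that get_leaf/get_top identify
# the sink and source call nodes and that assign_bfs_level_to_nodes labels
# each call node with its breadth-first depth from the input.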
def test_graph_manipulation():
    model = MLP(4)
    tracer = ColoTracer()
    graph = tracer.trace(model)
    nodes = list(graph.nodes)
    # Node order follows the forward pass: the input placeholder, the five
    # linear call nodes, then the output node.
    x, l1, l2, l3, l4, l5, output = nodes
    leaf_nodes = set(get_leaf(graph))
    top_nodes = set(get_top(graph))
    # Expected BFS levels: l1/l2 consume the input directly (level 0),
    # l3/l4 sit one hop deeper, and l5 is two hops deep; placeholder and
    # output nodes carry no level at all.
    compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None}
    assign_bfs_level_to_nodes(graph)

    assert leaf_nodes == {l4, l5}
    assert top_nodes == {l1, l2}
    for node in graph.nodes:
        if node.op in ('placeholder', 'output'):
            assert not hasattr(node, 'bfs_level')
        else:
            assert node.bfs_level == compare_dict[node]


if __name__ == '__main__':
    test_graph_manipulation()