ColossalAI/tests/test_fx/test_pipeline/test_topo/topo_utils.py

92 lines
3.0 KiB
Python
Raw Normal View History

import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.middleware.adaptor import get_fx_topology
import random
import numpy as np
# Single seed shared by every random number generator used in these tests,
# so that repeated runs see the same random values.
MANUAL_SEED = 0

# The three generators are independent of each other; each one is seeded
# once at import time.
torch.manual_seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
random.seed(MANUAL_SEED)
class MLP(torch.nn.Module):
    """Stack of square, bias-free linear layers applied one after another.

    Args:
        config: mapping providing ``'dim'`` (input and output feature size of
            every layer) and ``'layers'`` (how many layers to create).

    Raises:
        KeyError: if ``'dim'`` or ``'layers'`` is missing from ``config``
            (this includes the case where ``config`` is omitted).
    """

    def __init__(self, config=None):
        super().__init__()
        # ``None`` stands in for "no config" so that no mutable dict is shared
        # between calls; an omitted config still behaves like an empty mapping.
        if config is None:
            config = {}
        dim = config['dim']
        layers = config['layers']
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(dim, dim, bias=False) for _ in range(layers)]
        )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
def split_model_and_get_DAG(model, data_gen):
    """Trace ``model``, split it into two parts and build its topology.

    Args:
        model: module to trace; it is switched to eval mode first.
        data_gen: zero-argument callable returning a dict of sample inputs
            whose values support ``.to('meta')``.

    Returns:
        A ``(top_module, topo)`` tuple: the top-level module returned by
        ``split_with_split_nodes_pass`` and the topology returned by
        ``get_fx_topology`` for that module.

    Raises:
        RuntimeError: if building the meta inputs or tracing the model fails;
            the original exception is attached as the cause.
    """
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        # Chain explicitly so the underlying tracing error stays attached.
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") from e
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    # Attach the same topology object to every split graph module.
    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_topo', topo)

    # Return the topology directly instead of reading it back from the first
    # submodule: it is the same object, and this does not fail when the first
    # entry is missing or is not a GraphModule.
    return top_module, topo
def check_input(top_module, input_partition: Partition):
    """Check the input partition against the placeholders of ``top_module``.

    Placeholder nodes are matched, in graph order, with the output values of
    ``input_partition``. For each pair, the length of what the output value's
    ``get()`` returns must equal the number of users of the placeholder node.
    Finally, the number of placeholders must equal the number of output
    values.

    Raises:
        AssertionError: if any of the counts disagree.
        IndexError: if the graph has more placeholders than the partition
            has output values.
    """
    output_vals = input_partition.get_output_vals()
    placeholders = (n for n in top_module.graph.nodes if n.op == 'placeholder')

    matched = 0
    for matched, placeholder in enumerate(placeholders, start=1):
        # ``matched`` is 1-based, so the paired output value is one slot back.
        destinations = output_vals[matched - 1].get()
        assert len(destinations) == len(placeholder.users.keys())

    assert matched == len(output_vals)
def check_submod(top_module, part_id, mid_partition: Partition):
    """Check one partition against its node in the graph of ``top_module``.

    The node is located by walking the graph with a counter that starts at 1
    and is incremented on every node whose name starts with ``'submod'``; the
    first node at which the counter equals ``part_id`` is selected. The
    partition must then have as many input values as the node has args and
    as many output values as the node has users.

    Args:
        top_module: module whose ``graph.nodes`` are searched.
        part_id: partition id used to select the node.
        mid_partition: partition providing ``get_input_vals()`` and
            ``get_output_vals()``.

    Raises:
        AssertionError: if no node corresponds to ``part_id`` or if the
            input/output counts do not match the node.
    """
    partition_input = mid_partition.get_input_vals()
    partition_output = mid_partition.get_output_vals()

    cnt = 1
    cur_node = None
    for node in top_module.graph.nodes:
        if node.name.startswith('submod'):
            cnt += 1
        if cnt == part_id:
            cur_node = node
            break

    # Fail with a readable message instead of an AttributeError on ``None``.
    assert cur_node is not None, f"no graph node found for partition id {part_id}"
    assert len(partition_input) == len(cur_node.args)
    assert len(partition_output) == len(cur_node.users)
def check_topo(top_module, topo: Topo):
    """Validate ``topo`` against the graph of ``top_module``.

    The input partition is verified with ``check_input``; afterwards every
    entry of the mid-partition mapping is verified with ``check_submod``,
    using the mapping key as the partition id.
    """
    # Fetch both pieces of the topology before running any check.
    entry_partition = topo.get_input_partition()
    partitions_by_id = topo.get_mid_partitions()

    check_input(top_module, entry_partition)

    for partition_id, partition in partitions_by_id.items():
        check_submod(top_module, partition_id, partition)