[fx]add split module pass and unit test from pipeline passes (#1242)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx]add split module pass and unit test from pipeline passes

* fix MNASNet bug

* polish
YuliangLiu0306 2022-07-12 13:45:01 +08:00 committed by GitHub
parent 762905da68
commit 30b4fc0eb0
11 changed files with 702 additions and 3 deletions
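
For orientation, the new passes are meant to be chained: trace the model with ColoTracer, annotate the graph with balanced_split_pass, then materialize the partitions with split_with_split_nodes_pass. Below is a minimal sketch distilled from the unit tests in this commit; TinyNet is a hypothetical two-layer model used only for illustration, while the import paths and call signatures are the ones introduced here.

import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass

class TinyNet(torch.nn.Module):
    # hypothetical model, not part of this commit
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = TinyNet()
graph = ColoTracer().trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
annotated = balanced_split_pass(gm, pp_size=2)                           # insert pipe_split annotations
split_model, split_submodules = split_with_split_nodes_pass(annotated)   # build one submodule per partition
print(split_model)   # expected: submod_* children, one per pipeline stage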


@@ -2,7 +2,7 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
-from torch.fx.passes.split_module import split_module
+from colossalai.fx.passes.split_module import split_module
def pipe_split():
@@ -26,8 +26,14 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_param_amount >= params_per_partition:
accumulate_param_amount = 0
pp_size -= 1
-with mod_graph.inserting_after(node):
-    split_node = mod_graph.create_node('call_function', pipe_split)
+# If the next node is the output node, insert the split annotation before
+# this node to make sure there is at least one node in the last partition.
+if node.next.op == 'output':
+    with mod_graph.inserting_before(node):
+        split_node = mod_graph.create_node('call_function', pipe_split)
+else:
+    with mod_graph.inserting_after(node):
+        split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm
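
A quick way to sanity-check the annotation pass (a sketch reusing the hypothetical TinyNet from the example near the top, assuming colossalai is installed): collect the pipe_split markers that balanced_split_pass inserts and confirm that none of them directly precedes the graph output, which is exactly what the new branch above guarantees.

from torch.fx import symbolic_trace
from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass

gm = symbolic_trace(TinyNet())                  # TinyNet: hypothetical model from the sketch above
annotated = balanced_split_pass(gm, pp_size=2)
boundaries = [n for n in annotated.graph.nodes
              if n.op == 'call_function' and n.target is pipe_split]
# With the insert-before-output branch, a pipe_split marker is never the last
# non-output node, so the final partition always keeps at least one compute node.
assert all(n.next.op != 'output' for n in boundaries)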


@@ -0,0 +1,277 @@
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
import inspect
@compatibility(is_backward_compatible=True)
class Partition:
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
"""
def __init__(self, name: str):
self.name: str = name
self.node_names: List[str] = []
self.inputs: Dict[str, None] = {}
self.outputs: Dict[str, None] = {}
self.partitions_dependent_on: Dict[str, None] = {}
self.partition_dependents: Dict[str, None] = {}
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
return f"name: {self.name},\n" \
f" nodes: {self.node_names},\n" \
f" inputs: {self.inputs},\n" \
f" outputs: {self.outputs},\n" \
f" partitions depenent on: {self.partitions_dependent_on},\n" \
f" parition dependents: {self.partition_dependents}"
# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
Creates subgraphs out of main graph
Args:
m (GraphModule): Graph module to split
root_m (torch.nn.Module): root nn module. Not currently used. Included
because the root nn module is usually transformed via
torch.fx._symbolic_trace.symbolic_trace (see example below)
split_callback (Callable[[torch.fx.node.Node], int]): Callable function
that maps a given Node instance to a numeric partition identifier.
split_module will use this function as the policy for which operations
appear in which partitions in the output Module.
Returns:
GraphModule: the module after split.
Example:
This is a sample setup:
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x, y):
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
w = self.linear(y).clamp(min=0.0, max=1.0)
return z + w
# symbolically trace model
my_module = MyModule()
my_module_traced = symbolic_trace(my_module)
# random mod partitioning
partition_counter = 0
NPARTITIONS = 3
def mod_partition(node: Node):
global partition_counter
partition = partition_counter % NPARTITIONS
partition_counter = (partition_counter + 1) % NPARTITIONS
return partition
# split module in module with submodules
module_with_submodules = split_module(
my_module_traced, my_module, mod_partition
)
Output looks like this. Original graph is broken into partitions
> print(module_with_submodules)
GraphModule(
(submod_0): GraphModule(
(linear): Linear(in_features=4, out_features=5, bias=True)
)
(submod_1): GraphModule(
(linear): Linear(in_features=4, out_features=5, bias=True)
)
(submod_2): GraphModule()
)
def forward(self, x, y):
param = self.param
submod_0 = self.submod_0(x, param, y); x = param = y = None
getitem = submod_0[0]
getitem_1 = submod_0[1]; submod_0 = None
submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
getitem_2 = submod_1[0]
getitem_3 = submod_1[1]; submod_1 = None
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
return submod_2
Output of split module is the same as output of input traced module.
This is an example within a test setting:
> orig_out = my_module_traced(x, y)
> submodules_out = module_with_submodules(x, y)
> self.assertEqual(orig_out, submodules_out)
True
"""
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
# split nodes into partitions
for node in m.graph.nodes:
orig_nodes[node.name] = node
if node.op in ["placeholder"]:
continue
if node.op == 'output':
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))
# add node to partitions
partition = partitions.get(partition_name)
if partition is None:
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
for partition_name, partition in partitions.items():
if not len(partition.partitions_dependent_on):
root_partitions.append(partition_name)
# check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].partition_dependents:
partitions[dependent].partitions_dependent_on.pop(root_partition)
if not partitions[dependent].partitions_dependent_on:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")
# add placeholders to partitions
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
placeholder = partition.graph.placeholder(input)
placeholder.meta = orig_nodes[input].meta.copy()
partition.environment[orig_nodes[input]] = placeholder
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
target = node.target
else:
target_atoms = node.target.split('.')
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
# Set up values to construct base module
base_mod_env: Dict[str, torch.fx.node.Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
# 1) Finish off submodule Graphs by setting corresponding outputs
# 2) Construct GraphModules for each submodule
# 3) Construct the base graph by emitting calls to those submodules in
# topological order
for partition_name in sorted_partitions:
partition = partitions[partition_name]
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
if len(partition.outputs) > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
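
One deliberate departure from the upstream torch.fx pass is visible in the commented-out line above: instead of keeping only the last atom of a dotted call_module/get_attr target, this adaptation joins all atoms with underscores, so leaf modules that share a name under different parents do not collide when registered on a partition submodule. A tiny illustration (the target string is hypothetical):

target_atoms = 'layer1.0.conv'.split('.')   # hypothetical nested target from the traced module
'_'.join(target_atoms)                      # -> 'layer1_0_conv'  (this adaptation)
target_atoms[-1]                            # -> 'conv'           (upstream behavior, kept as a comment above)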


@@ -0,0 +1,69 @@
import torch
from torch.fx import symbolic_trace
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import inspect
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
def split_model_and_compare_output(model, data_gen):
model.eval()
# generate input sample
kwargs = data_gen()
# get origin output and rng state
cpu_rng_state = torch.get_rng_state()
output = model(**kwargs)
# tracing model
tracer = ColoTracer()
try:
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
# get split model
model_part0 = list(split_model.children())[0]
model_part1 = list(split_model.children())[1]
# set rng state and compute output of split model
torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(**kwargs)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
output_part1 = model_part1(output_part0)
else:
if len(output_part0) > len(sig.parameters):
output_part0 = output_part0[:len(sig.parameters)]
output_part1 = model_part1(*output_part0)
# get the output tensor from the HuggingFace ModelOutput data structure
if 'logits' in output:
output_to_compare = output['logits']
elif 'prediction_logits' in output:
output_to_compare = output['prediction_logits']
else:
output_to_compare = output['last_hidden_state']
# compare output
if isinstance(output_part1, torch.Tensor):
assert output_to_compare.equal(output_part1)
elif isinstance(output_part1, (tuple, list)):
assert output_to_compare.equal(output_part1[0])
else:
assert False, f'unexpected output type: {type(output_part1)}'


@@ -0,0 +1,38 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
def test_single_sentence_albert():
MODEL_LIST = [
transformers.AlbertModel,
transformers.AlbertForPreTraining,
transformers.AlbertForMaskedLM,
transformers.AlbertForSequenceClassification,
transformers.AlbertForTokenClassification,
]
config = transformers.AlbertConfig(vocab_size=100,
embedding_size=128,
hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_single_sentence_albert()


@@ -0,0 +1,38 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
def test_single_sentence_bert():
MODEL_LIST = [
transformers.BertModel,
transformers.BertForPreTraining,
transformers.BertLMHeadModel,
transformers.BertForMaskedLM,
transformers.BertForSequenceClassification,
transformers.BertForTokenClassification,
]
config = transformers.BertConfig(vocab_size=100,
hidden_size=128,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=256)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_single_sentence_bert()


@@ -0,0 +1,34 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 64
SEQ_LENGHT = 16
NUM_EPOCHS = 2
NUM_CHUNKS = 1
def test_gpt():
MODEL_LIST = [
transformers.GPT2Model,
transformers.GPT2LMHeadModel,
transformers.GPT2DoubleHeadsModel,
transformers.GPT2ForTokenClassification,
# transformers.GPT2ForSequenceClassification, # not supported yet
]
config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=8)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_gpt()


@@ -0,0 +1,30 @@
import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
def test_opt():
MODEL_LIST = [
transformers.OPTModel,
transformers.OPTForCausalLM,
]
config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_opt()


@@ -0,0 +1,43 @@
import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
@pytest.mark.skip('tracing failed')
def test_t5():
MODEL_LIST = [
transformers.T5Model,
transformers.T5ForConditionalGeneration,
transformers.T5EncoderModel,
]
config = transformers.T5Config(d_model=128, num_layers=2)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
return kwargs
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
if isinstance(model, transformers.T5EncoderModel):
data_gen_func = data_gen_for_encoder_only
else:
data_gen_func = data_gen
split_model_and_compare_output(model, data_gen_func)
if __name__ == '__main__':
test_t5()


@@ -0,0 +1,51 @@
import torch
import pytest
try:
import timm.models as tm
except ImportError:
pass
from timm_utils import split_model_and_compare_output
@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow():
MODEL_LIST = [
tm.resnest.resnest50d,
tm.beit.beit_base_patch16_224,
tm.cait.cait_s24_224,
tm.convmixer.convmixer_768_32,
tm.efficientnet.efficientnetv2_m,
tm.resmlp_12_224,
tm.vision_transformer.vit_base_patch16_224,
tm.deit_base_distilled_patch16_224,
]
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
model = model_cls()
split_model_and_compare_output(model, data)
@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True
MODEL_LIST_WITH_CONTROL_FLOW = [
tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
tm.swin_transformer.swin_base_patch4_window7_224
]
data = torch.rand(2, 3, 224, 224)
meta_args = {'x': data.to('meta')}
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
model = model_cls()
split_model_and_compare_output(model, data, meta_args)
if __name__ == '__main__':
test_timm_models_without_control_flow()
test_timm_models_with_control_flow()


@@ -0,0 +1,51 @@
import torch
from torch.fx import symbolic_trace
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import inspect
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True
def split_model_and_compare_output(model, data, meta_args=None):
model.eval()
# get origin output and rng state
cpu_rng_state = torch.get_rng_state()
output = model(data)
# tracing model
tracer = ColoTracer()
try:
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
# get split model
model_part0 = list(split_model.children())[0]
model_part1 = list(split_model.children())[1]
# set rng state and compute output of split model
torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(data)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
output_part1 = model_part1(output_part0)
else:
if len(output_part0) > len(sig.parameters):
output_part0 = output_part0[:len(sig.parameters)]
output_part1 = model_part1(*output_part0)
assert output.equal(output_part1)


@@ -0,0 +1,62 @@
import torch
import pytest
try:
import torchvision.models as tm
except ImportError:
pass
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule
import random
import numpy as np
import inspect
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True
@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
model = model_cls()
model.eval()
cpu_rng_state = torch.get_rng_state()
output = model(data)
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
# get split model
model_part0 = list(split_model.children())[0]
model_part1 = list(split_model.children())[1]
# set rng state and compute output of split model
torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(data)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
output_part1 = model_part1(output_part0)
else:
if len(output_part0) > len(sig.parameters):
output_part0 = output_part0[:len(sig.parameters)]
output_part1 = model_part1(*output_part0)
assert output.equal(output_part1)
if __name__ == '__main__':
test_torchvision_models()