mirror of https://github.com/hpcaitech/ColossalAI
[fx]add split module pass and unit test from pipeline passes (#1242)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx]add split module pass and unit test from pipeline passes
* fix MNASNet bug
* polish
pull/1253/head
parent
762905da68
commit
30b4fc0eb0
|
@ -2,7 +2,7 @@ import torch
|
||||||
|
|
||||||
from torch.fx import symbolic_trace
|
from torch.fx import symbolic_trace
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from torch.fx.passes.split_module import split_module
|
from colossalai.fx.passes.split_module import split_module
|
||||||
|
|
||||||
|
|
||||||
def pipe_split():
|
def pipe_split():
|
||||||
|
@ -26,8 +26,14 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||||
if accumulate_param_amount >= params_per_partition:
|
if accumulate_param_amount >= params_per_partition:
|
||||||
accumulate_param_amount = 0
|
accumulate_param_amount = 0
|
||||||
pp_size -= 1
|
pp_size -= 1
|
||||||
with mod_graph.inserting_after(node):
|
# If the next node is output node, we will insert split annotation before
|
||||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
# node to make sure there is at least one node in last partition.
|
||||||
|
if node.next.op == 'output':
|
||||||
|
with mod_graph.inserting_before(node):
|
||||||
|
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||||
|
else:
|
||||||
|
with mod_graph.inserting_after(node):
|
||||||
|
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,277 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx.graph_module import GraphModule
|
||||||
|
from typing import Callable, List, Dict, Any, Optional
|
||||||
|
from torch.fx._compatibility import compatibility
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
class Partition:
|
||||||
|
"""
|
||||||
|
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name: str = name
|
||||||
|
self.node_names: List[str] = []
|
||||||
|
self.inputs: Dict[str, None] = {}
|
||||||
|
self.outputs: Dict[str, None] = {}
|
||||||
|
self.partitions_dependent_on: Dict[str, None] = {}
|
||||||
|
self.partition_dependents: Dict[str, None] = {}
|
||||||
|
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||||
|
self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
|
||||||
|
self.targets: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"name: {self.name},\n" \
|
||||||
|
f" nodes: {self.node_names},\n" \
|
||||||
|
f" inputs: {self.inputs},\n" \
|
||||||
|
f" outputs: {self.outputs},\n" \
|
||||||
|
f" partitions depenent on: {self.partitions_dependent_on},\n" \
|
||||||
|
f" parition dependents: {self.partition_dependents}"
|
||||||
|
|
||||||
|
|
||||||
|
# Creates subgraphs out of main graph
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def split_module(
|
||||||
|
m: GraphModule,
|
||||||
|
root_m: torch.nn.Module,
|
||||||
|
split_callback: Callable[[torch.fx.node.Node], int],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
|
||||||
|
Creates subgraphs out of main graph
|
||||||
|
|
||||||
|
Args:
|
||||||
|
m (GraphModule): Graph module to split
|
||||||
|
root_m (torch.nn.Module): root nn module. Not currently used. Included
|
||||||
|
because the root nn module is usually transformed via
|
||||||
|
torch.fx._symbolic_trace.symbolic_trace (see example below)
|
||||||
|
split_callback (Callable[[torch.fx.node.Node], int]): Callable function
|
||||||
|
that maps a given Node instance to a numeric partition identifier.
|
||||||
|
split_module will use this function as the policy for which operations
|
||||||
|
appear in which partitions in the output Module.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GraphModule: the module after split.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
This is a sample setup:
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.fx.symbolic_trace import symbolic_trace
|
||||||
|
from torch.fx.graph_module import GraphModule
|
||||||
|
from torch.fx.node import Node
|
||||||
|
from colossalai.fx.passes.split_module import split_module
|
||||||
|
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
||||||
|
self.linear = torch.nn.Linear(4, 5)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
||||||
|
w = self.linear(y).clamp(min=0.0, max=1.0)
|
||||||
|
return z + w
|
||||||
|
|
||||||
|
# symbolically trace model
|
||||||
|
my_module = MyModule()
|
||||||
|
my_module_traced = symbolic_trace(my_module)
|
||||||
|
|
||||||
|
# random mod partitioning
|
||||||
|
partition_counter = 0
|
||||||
|
NPARTITIONS = 3
|
||||||
|
|
||||||
|
def mod_partition(node: Node):
|
||||||
|
global partition_counter
|
||||||
|
partition = partition_counter % NPARTITIONS
|
||||||
|
partition_counter = (partition_counter + 1) % NPARTITIONS
|
||||||
|
return partition
|
||||||
|
|
||||||
|
# split module in module with submodules
|
||||||
|
module_with_submodules = split_module(
|
||||||
|
my_module_traced, my_module, mod_partition
|
||||||
|
)
|
||||||
|
|
||||||
|
Output looks like this. Original graph is broken into partitions
|
||||||
|
|
||||||
|
> print(module_with_submodules)
|
||||||
|
GraphModule(
|
||||||
|
(submod_0): GraphModule(
|
||||||
|
(linear): Linear(in_features=4, out_features=5, bias=True)
|
||||||
|
)
|
||||||
|
(submod_1): GraphModule(
|
||||||
|
(linear): Linear(in_features=4, out_features=5, bias=True)
|
||||||
|
)
|
||||||
|
(submod_2): GraphModule()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
param = self.param
|
||||||
|
submod_0 = self.submod_0(x, param, y); x = param = y = None
|
||||||
|
getitem = submod_0[0]
|
||||||
|
getitem_1 = submod_0[1]; submod_0 = None
|
||||||
|
submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
|
||||||
|
getitem_2 = submod_1[0]
|
||||||
|
getitem_3 = submod_1[1]; submod_1 = None
|
||||||
|
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
|
||||||
|
return submod_2
|
||||||
|
|
||||||
|
Output of split module is the same as output of input traced module.
|
||||||
|
This is an example within a test setting:
|
||||||
|
|
||||||
|
> orig_out = my_module_traced(x, y)
|
||||||
|
> submodules_out = module_with_submodules(x, y)
|
||||||
|
> self.assertEqual(orig_out, submodules_out)
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
partitions: Dict[str, Partition] = {}
|
||||||
|
orig_nodes: Dict[str, torch.fx.node.Node] = {}
|
||||||
|
|
||||||
|
def record_cross_partition_use(def_node: torch.fx.node.Node,
|
||||||
|
use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||||
|
def_partition_name = getattr(def_node, '_fx_partition', None)
|
||||||
|
use_partition_name = getattr(use_node, '_fx_partition', None)
|
||||||
|
if def_partition_name != use_partition_name:
|
||||||
|
if def_partition_name is not None:
|
||||||
|
def_partition = partitions[def_partition_name]
|
||||||
|
def_partition.outputs.setdefault(def_node.name)
|
||||||
|
if use_partition_name is not None:
|
||||||
|
def_partition.partition_dependents.setdefault(use_partition_name)
|
||||||
|
|
||||||
|
if use_partition_name is not None:
|
||||||
|
use_partition = partitions[use_partition_name]
|
||||||
|
use_partition.inputs.setdefault(def_node.name)
|
||||||
|
if def_partition_name is not None:
|
||||||
|
use_partition.partitions_dependent_on.setdefault(def_partition_name)
|
||||||
|
|
||||||
|
# split nodes into parititons
|
||||||
|
for node in m.graph.nodes:
|
||||||
|
orig_nodes[node.name] = node
|
||||||
|
|
||||||
|
if node.op in ["placeholder"]:
|
||||||
|
continue
|
||||||
|
if node.op == 'output':
|
||||||
|
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
|
||||||
|
continue
|
||||||
|
partition_name = str(split_callback(node))
|
||||||
|
|
||||||
|
# add node to partitions
|
||||||
|
partition = partitions.get(partition_name)
|
||||||
|
if partition is None:
|
||||||
|
partitions[partition_name] = partition = Partition(partition_name)
|
||||||
|
|
||||||
|
partition.node_names.append(node.name)
|
||||||
|
node._fx_partition = partition_name
|
||||||
|
|
||||||
|
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
|
||||||
|
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||||
|
|
||||||
|
# find partitions with no dependencies
|
||||||
|
root_partitions: List[str] = []
|
||||||
|
for partition_name, partition in partitions.items():
|
||||||
|
if not len(partition.partitions_dependent_on):
|
||||||
|
root_partitions.append(partition_name)
|
||||||
|
|
||||||
|
# check partitions for circular dependencies and create topological partition ordering
|
||||||
|
sorted_partitions: List[str] = []
|
||||||
|
while root_partitions:
|
||||||
|
root_partition = root_partitions.pop()
|
||||||
|
sorted_partitions.append(root_partition)
|
||||||
|
for dependent in partitions[root_partition].partition_dependents:
|
||||||
|
partitions[dependent].partitions_dependent_on.pop(root_partition)
|
||||||
|
if not partitions[dependent].partitions_dependent_on:
|
||||||
|
root_partitions.append(dependent)
|
||||||
|
if len(sorted_partitions) != len(partitions):
|
||||||
|
raise RuntimeError("cycle exists between partitions!")
|
||||||
|
|
||||||
|
# add placeholders to parititons
|
||||||
|
for partition_name in sorted_partitions:
|
||||||
|
partition = partitions[partition_name]
|
||||||
|
for input in partition.inputs:
|
||||||
|
placeholder = partition.graph.placeholder(input)
|
||||||
|
placeholder.meta = orig_nodes[input].meta.copy()
|
||||||
|
partition.environment[orig_nodes[input]] = placeholder
|
||||||
|
|
||||||
|
# Transform nodes and collect targets for partition's submodule
|
||||||
|
for node in m.graph.nodes:
|
||||||
|
if hasattr(node, '_fx_partition'):
|
||||||
|
partition = partitions[node._fx_partition]
|
||||||
|
|
||||||
|
# swap out old graph nodes in kw/args with references to new nodes in this submodule
|
||||||
|
environment = partition.environment
|
||||||
|
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
|
||||||
|
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
|
||||||
|
|
||||||
|
if node.op not in ['call_module', 'get_attr']:
|
||||||
|
target = node.target
|
||||||
|
else:
|
||||||
|
target_atoms = node.target.split('.')
|
||||||
|
target_attr = m
|
||||||
|
for atom in target_atoms:
|
||||||
|
if not hasattr(target_attr, atom):
|
||||||
|
raise RuntimeError(f'Operator target {node.target} not found!')
|
||||||
|
target_attr = getattr(target_attr, atom)
|
||||||
|
# target = target_atoms[-1]
|
||||||
|
target = '_'.join(target_atoms)
|
||||||
|
partition.targets[target] = target_attr
|
||||||
|
|
||||||
|
assert isinstance(gathered_args, tuple)
|
||||||
|
assert isinstance(gathered_kwargs, dict)
|
||||||
|
new_node = partition.graph.create_node(op=node.op,
|
||||||
|
target=target,
|
||||||
|
args=gathered_args,
|
||||||
|
kwargs=gathered_kwargs)
|
||||||
|
new_node.meta = node.meta.copy()
|
||||||
|
partition.environment[node] = new_node
|
||||||
|
|
||||||
|
# Set up values to construct base module
|
||||||
|
base_mod_env: Dict[str, torch.fx.node.Node] = {}
|
||||||
|
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||||
|
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
|
||||||
|
for node in m.graph.nodes:
|
||||||
|
if node.op == 'placeholder':
|
||||||
|
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
||||||
|
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
|
||||||
|
type_expr=node.type,
|
||||||
|
default_value=default_value)
|
||||||
|
base_mod_env[node.name].meta = node.meta.copy()
|
||||||
|
|
||||||
|
# Do some things iterating over the partitions in topological order again:
|
||||||
|
# 1) Finish off submodule Graphs by setting corresponding outputs
|
||||||
|
# 2) Construct GraphModules for each submodule
|
||||||
|
# 3) Construct the base graph by emitting calls to those submodules in
|
||||||
|
# topological order
|
||||||
|
|
||||||
|
for partition_name in sorted_partitions:
|
||||||
|
partition = partitions[partition_name]
|
||||||
|
|
||||||
|
# Set correct output values
|
||||||
|
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
|
||||||
|
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||||
|
partition.graph.output(output_vals)
|
||||||
|
|
||||||
|
# Construct GraphModule for this partition
|
||||||
|
submod_name = f'submod_{partition_name}'
|
||||||
|
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
|
||||||
|
partition.graph) # noqa: B950
|
||||||
|
|
||||||
|
# Emit call in base graph to this submodule
|
||||||
|
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
|
||||||
|
if len(partition.outputs) > 1:
|
||||||
|
# Unpack multiple return values from submodule
|
||||||
|
output_val_proxy = torch.fx.proxy.Proxy(output_val)
|
||||||
|
for i, output_name in enumerate(partition.outputs):
|
||||||
|
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||||
|
else:
|
||||||
|
if not partition.outputs:
|
||||||
|
continue
|
||||||
|
base_mod_env[list(partition.outputs)[0]] = output_val
|
||||||
|
|
||||||
|
for node in m.graph.nodes:
|
||||||
|
if node.op == 'output':
|
||||||
|
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||||
|
|
||||||
|
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|
|
@ -0,0 +1,69 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx import symbolic_trace
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
import inspect
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
MANUAL_SEED = 0
|
||||||
|
random.seed(MANUAL_SEED)
|
||||||
|
np.random.seed(MANUAL_SEED)
|
||||||
|
torch.manual_seed(MANUAL_SEED)
|
||||||
|
|
||||||
|
|
||||||
|
def split_model_and_compare_output(model, data_gen):
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# generate input sample
|
||||||
|
kwargs = data_gen()
|
||||||
|
|
||||||
|
# get origin output and rng state
|
||||||
|
cpu_rng_state = torch.get_rng_state()
|
||||||
|
output = model(**kwargs)
|
||||||
|
|
||||||
|
# tracing model
|
||||||
|
tracer = ColoTracer()
|
||||||
|
try:
|
||||||
|
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||||
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
# apply transform passes
|
||||||
|
annotated_model = balanced_split_pass(gm, 2)
|
||||||
|
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
|
||||||
|
|
||||||
|
# get split model
|
||||||
|
model_part0 = list(split_model.children())[0]
|
||||||
|
model_part1 = list(split_model.children())[1]
|
||||||
|
|
||||||
|
# set rng state and compute output of split model
|
||||||
|
torch.set_rng_state(cpu_rng_state)
|
||||||
|
output_part0 = model_part0(**kwargs)
|
||||||
|
sig = inspect.signature(model_part1.forward)
|
||||||
|
if isinstance(output_part0, torch.Tensor):
|
||||||
|
output_part1 = model_part1(output_part0)
|
||||||
|
else:
|
||||||
|
if len(output_part0) > len(sig.parameters):
|
||||||
|
output_part0 = output_part0[:len(sig.parameters)]
|
||||||
|
output_part1 = model_part1(*output_part0)
|
||||||
|
|
||||||
|
# get output tensor from HFOutput datastructure
|
||||||
|
if 'logits' in output:
|
||||||
|
output_to_compare = output['logits']
|
||||||
|
elif 'prediction_logits' in output:
|
||||||
|
output_to_compare = output['prediction_logits']
|
||||||
|
else:
|
||||||
|
output_to_compare = output['last_hidden_state']
|
||||||
|
|
||||||
|
# compare output
|
||||||
|
if isinstance(output_part1, torch.Tensor):
|
||||||
|
assert output_to_compare.equal(output_part1)
|
||||||
|
elif isinstance(output_part1, (tuple, list)):
|
||||||
|
assert output_to_compare.equal(output_part1[0])
|
||||||
|
else:
|
||||||
|
assert False
|
|
@ -0,0 +1,38 @@
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from hf_utils import split_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 2
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_sentence_albert():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.AlbertModel,
|
||||||
|
transformers.AlbertForPreTraining,
|
||||||
|
transformers.AlbertForMaskedLM,
|
||||||
|
transformers.AlbertForSequenceClassification,
|
||||||
|
transformers.AlbertForTokenClassification,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.AlbertConfig(vocab_size=100,
|
||||||
|
embedding_size=128,
|
||||||
|
hidden_size=128,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=256)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||||
|
return meta_args
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
split_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_single_sentence_albert()
|
|
@ -0,0 +1,38 @@
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from hf_utils import split_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 2
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_sentence_bert():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.BertModel,
|
||||||
|
transformers.BertForPreTraining,
|
||||||
|
transformers.BertLMHeadModel,
|
||||||
|
transformers.BertForMaskedLM,
|
||||||
|
transformers.BertForSequenceClassification,
|
||||||
|
transformers.BertForTokenClassification,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.BertConfig(vocab_size=100,
|
||||||
|
hidden_size=128,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=256)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||||
|
return meta_args
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
split_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_single_sentence_bert()
|
|
@ -0,0 +1,34 @@
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from hf_utils import split_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
NUM_EPOCHS = 2
|
||||||
|
NUM_CHUNKS = 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.GPT2Model,
|
||||||
|
transformers.GPT2LMHeadModel,
|
||||||
|
transformers.GPT2DoubleHeadsModel,
|
||||||
|
transformers.GPT2ForTokenClassification,
|
||||||
|
# transformers.GPT2ForSequenceClassification, # not supported yet
|
||||||
|
]
|
||||||
|
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
split_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_gpt()
|
|
@ -0,0 +1,30 @@
|
||||||
|
import pytest
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from hf_utils import split_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 1
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
def test_opt():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.OPTModel,
|
||||||
|
transformers.OPTForCausalLM,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
split_model_and_compare_output(model, data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_opt()
|
|
@ -0,0 +1,43 @@
|
||||||
|
import pytest
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
from hf_utils import split_model_and_compare_output
|
||||||
|
|
||||||
|
BATCH_SIZE = 1
|
||||||
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip('tracing failed')
|
||||||
|
def test_t5():
|
||||||
|
MODEL_LIST = [
|
||||||
|
transformers.T5Model,
|
||||||
|
transformers.T5ForConditionalGeneration,
|
||||||
|
transformers.T5EncoderModel,
|
||||||
|
]
|
||||||
|
|
||||||
|
config = transformers.T5Config(d_model=128, num_layers=2)
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def data_gen_for_encoder_only():
|
||||||
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
kwargs = dict(input_ids=input_ids)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls(config=config)
|
||||||
|
|
||||||
|
if isinstance(model, transformers.T5EncoderModel):
|
||||||
|
data_gen_func = data_gen_for_encoder_only
|
||||||
|
else:
|
||||||
|
data_gen_func = data_gen
|
||||||
|
|
||||||
|
split_model_and_compare_output(model, data_gen_func)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_t5()
|
|
@ -0,0 +1,51 @@
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
try:
|
||||||
|
import timm.models as tm
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
from timm_utils import split_model_and_compare_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip('skip as timm is required')
|
||||||
|
def test_timm_models_without_control_flow():
|
||||||
|
|
||||||
|
MODEL_LIST = [
|
||||||
|
tm.resnest.resnest50d,
|
||||||
|
tm.beit.beit_base_patch16_224,
|
||||||
|
tm.cait.cait_s24_224,
|
||||||
|
tm.convmixer.convmixer_768_32,
|
||||||
|
tm.efficientnet.efficientnetv2_m,
|
||||||
|
tm.resmlp_12_224,
|
||||||
|
tm.vision_transformer.vit_base_patch16_224,
|
||||||
|
tm.deit_base_distilled_patch16_224,
|
||||||
|
]
|
||||||
|
|
||||||
|
data = torch.rand(2, 3, 224, 224)
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls()
|
||||||
|
split_model_and_compare_output(model, data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip('skip as timm is required')
|
||||||
|
def test_timm_models_with_control_flow():
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
MODEL_LIST_WITH_CONTROL_FLOW = [
|
||||||
|
tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
|
||||||
|
tm.swin_transformer.swin_base_patch4_window7_224
|
||||||
|
]
|
||||||
|
|
||||||
|
data = torch.rand(2, 3, 224, 224)
|
||||||
|
|
||||||
|
meta_args = {'x': data.to('meta')}
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
|
||||||
|
model = model_cls()
|
||||||
|
split_model_and_compare_output(model, data, meta_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_timm_models_without_control_flow()
|
||||||
|
test_timm_models_with_control_flow()
|
|
@ -0,0 +1,51 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx import symbolic_trace
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
import inspect
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
MANUAL_SEED = 0
|
||||||
|
random.seed(MANUAL_SEED)
|
||||||
|
np.random.seed(MANUAL_SEED)
|
||||||
|
torch.manual_seed(MANUAL_SEED)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
||||||
|
def split_model_and_compare_output(model, data, meta_args=None):
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# get origin output and rng state
|
||||||
|
cpu_rng_state = torch.get_rng_state()
|
||||||
|
output = model(data)
|
||||||
|
|
||||||
|
# tracing model
|
||||||
|
tracer = ColoTracer()
|
||||||
|
try:
|
||||||
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
# apply transform passes
|
||||||
|
annotated_model = balanced_split_pass(gm, 2)
|
||||||
|
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
|
||||||
|
|
||||||
|
# get split model
|
||||||
|
model_part0 = list(split_model.children())[0]
|
||||||
|
model_part1 = list(split_model.children())[1]
|
||||||
|
|
||||||
|
# set rng state and compute output of split model
|
||||||
|
torch.set_rng_state(cpu_rng_state)
|
||||||
|
output_part0 = model_part0(data)
|
||||||
|
sig = inspect.signature(model_part1.forward)
|
||||||
|
if isinstance(output_part0, torch.Tensor):
|
||||||
|
output_part1 = model_part1(output_part0)
|
||||||
|
else:
|
||||||
|
if len(output_part0) > len(sig.parameters):
|
||||||
|
output_part0 = output_part0[:len(sig.parameters)]
|
||||||
|
output_part1 = model_part1(*output_part0)
|
||||||
|
assert output.equal(output_part1)
|
|
@ -0,0 +1,62 @@
|
||||||
|
import torch
|
||||||
|
try:
|
||||||
|
import torchvision.models as tm
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
MANUAL_SEED = 0
|
||||||
|
random.seed(MANUAL_SEED)
|
||||||
|
np.random.seed(MANUAL_SEED)
|
||||||
|
torch.manual_seed(MANUAL_SEED)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip('skip as torchvision is required')
|
||||||
|
def test_torchvision_models():
|
||||||
|
MODEL_LIST = [
|
||||||
|
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
|
||||||
|
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
|
||||||
|
]
|
||||||
|
|
||||||
|
tracer = ColoTracer()
|
||||||
|
data = torch.rand(2, 3, 224, 224)
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
model = model_cls()
|
||||||
|
model.eval()
|
||||||
|
cpu_rng_state = torch.get_rng_state()
|
||||||
|
output = model(data)
|
||||||
|
graph = tracer.trace(root=model)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
# apply transform passes
|
||||||
|
annotated_model = balanced_split_pass(gm, 2)
|
||||||
|
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
|
||||||
|
|
||||||
|
# get split model
|
||||||
|
model_part0 = list(split_model.children())[0]
|
||||||
|
model_part1 = list(split_model.children())[1]
|
||||||
|
|
||||||
|
# set rng state and compute output of split model
|
||||||
|
torch.set_rng_state(cpu_rng_state)
|
||||||
|
output_part0 = model_part0(data)
|
||||||
|
sig = inspect.signature(model_part1.forward)
|
||||||
|
if isinstance(output_part0, torch.Tensor):
|
||||||
|
output_part1 = model_part1(output_part0)
|
||||||
|
else:
|
||||||
|
if len(output_part0) > len(sig.parameters):
|
||||||
|
output_part0 = output_part0[:len(sig.parameters)]
|
||||||
|
output_part1 = model_part1(*output_part0)
|
||||||
|
assert output.equal(output_part1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_torchvision_models()
|
Loading…
Reference in New Issue