diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index 48a76660d..bc257edc8 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -3,7 +3,6 @@ from torch.fx.graph_module import GraphModule from typing import Callable, List, Dict, Any, Optional from torch.fx._compatibility import compatibility from packaging import version -from colossalai.fx.passes.utils import get_DAG import inspect @@ -294,11 +293,5 @@ def split_module( partition = partitions[partition_name] new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) - - DAG = get_DAG(new_gm) - - for _, submodule in new_gm.named_modules(): - if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_DAG', DAG) return new_gm diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index fda010fd3..bb4f3cd6a 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,8 +1,7 @@ import torch -from typing import Dict, Set +from typing import Dict from torch.fx.node import Node, map_arg from torch.fx.graph import Graph -from torch.fx.graph_module import GraphModule def get_comm_size(prev_partition, next_partition): """ @@ -171,161 +170,3 @@ def get_node_module(node) -> torch.nn.Module: module = node.graph.owning_module.get_submodule(node.target) return module -def find_def_in_partition(node, partitions, input_partitions=None, direct=False): - # find def in input - if input_partitions is not None: - for placeholder in input_partitions: - if placeholder.name == node.name: - return 'MODEL_INPUT' - - # find direct def - if direct: - for partition in partitions: - if node == partition: - return partition.name - # find def with getitem call - else: - for partition in partitions: - if node in partition.users.keys(): - return partition.name - - print(f'Not found def in partition {node.name}') - return None - -def find_user_in_partition(node, partitions, output_partitions=None, direct=False): - user_partition_names = [] - # find direct user - if direct: - for partition in partitions: - if node == partition: - user_partition_names.append(partition.name) - - # find user with getitem call - else: - for partition in partitions: - if node in partition.args: - user_partition_names.append(partition.name) - - if output_partitions is not None: - output_node = output_partitions[0] - if node.op == output_node.op: - user_partition_names.append('MODEL_OUTPUT') - - if len(user_partition_names) > 0: - return user_partition_names - - print(f'Not found user in partition {node.name}') - return None - -def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None): - # e.g. 
Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}}, - input = {} - output = {} - - for offset, arg in enumerate(partition.args): - def_partition_name = None - if not arg.name.startswith('getitem'): - def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True) - else: - def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False) - if def_partition_name is None: - continue - if def_partition_name not in input: - input[def_partition_name] = [] - input[def_partition_name].append(offset) - - offset = -1 - for user in partition.users.keys(): - user_partition_names = None - if input_partitions is None or not user.name.startswith('getitem'): - user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True) - offset = 0 - else: - user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False) - offset += 1 - if user_partition_names is None: - continue - for user_partition_name in user_partition_names: - if user_partition_name not in output: - output[user_partition_name] = [] - output[user_partition_name].append(offset) - - return input, output, offset+1 - -# DAG just looks like following case. -# the int in every list represents the offset of the partition's input arg or output arg. -# { -# 'input_partition': { -# 'input_ids': { -# 'input': {}, -# 'output': {'submod_0': [0], 'submod_1': [1]}, -# 'output_len': 0}, -# 'attention_mask': { -# 'input': {}, -# 'output': {'submod_2': [0]}, -# 'output_len': 0}}, -# 'submod_0': { -# 'input': {'MODEL_INPUT': [0]}, -# 'output': {'submod_1': [0], 'submod_2': [0, 1]}, -# 'output_len': 2}, -# 'submod_1': { -# 'input': {'submod_0': [0], 'MODEL_INPUT': [1]}, -# 'output': {'submod_2': [0]}, -# 'output_len': 1}, -# 'submod_2': { -# 'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]}, -# 'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -# 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -# 22, 23, 24]}, -# 'output_len': 25}, -# 'submod_3': { -# 'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -# 12, 13, 14, 15, 16, 17, 18, 19, 20, -# 21, 22, 23, 24]}, -# 'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -# 11, 12, 13, 14, 15, 16, 17, 18, 19, -# 20, 21, 22, 23, 24]}, -# 'output_len': 25}, -# 'output_partition': { -# 'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'), -# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'), -# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'), -# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'), -# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'), -# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))}, -# 'output': {}, 'output_len': 0} -# } - -# TODO(jiangziyue) Define a Class for DAG. 
-def get_DAG(gm: GraphModule): - DAG = {} - input_partitions = [] - partitions = [] - output_partitions = [] - for node in gm.graph.nodes: - if node.op == 'placeholder': - input_partitions.append(node) - elif node.name.startswith('submod_'): - partitions.append(node) - elif node.op == 'output': - output_partitions.append(node) - - for partition in input_partitions: - DAG_node = {'input': {}, 'output': {}, 'output_len': 1} - _, output, _ = get_partition_depends(partition, partitions, None, output_partitions) - DAG_node['output'] = output - if 'input_partition' not in DAG: - DAG['input_partition'] = {} - DAG['input_partition'][partition.name] = DAG_node - - for partition in partitions: - DAG_node = {'input': {}, 'output': {}} - DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(partition, partitions, input_partitions, output_partitions) - DAG[partition.name] = DAG_node - - for partition in output_partitions: - DAG_node = {'input': {}, 'output': {}, 'output_len': 0} - DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0], lambda n: find_def_in_partition(n, partitions, input_partitions)) - DAG['output_partition'] = DAG_node - - return DAG \ No newline at end of file diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py new file mode 100644 index 000000000..79e19f9ea --- /dev/null +++ b/colossalai/pipeline/middleware/__init__.py @@ -0,0 +1,3 @@ +from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal + +__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/pipeline/middleware/adaptor/__init__.py new file mode 100644 index 000000000..949700a2c --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/__init__.py @@ -0,0 +1,3 @@ +from .fx import get_topology as get_fx_topology + +__all__ = ['get_fx_topology'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/pipeline/middleware/adaptor/fx.py new file mode 100644 index 000000000..4351a6b49 --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/fx.py @@ -0,0 +1,145 @@ +from torch.fx.graph_module import GraphModule +from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo +import torch + +def partition_name_to_id(partition_name, is_input=False, is_output=False): + if is_input: + partition_id = 0 + elif is_output: + partition_id = 1 + else: + prefix = 'submod_' + partition_id = int(partition_name.split(prefix)[-1]) + 2 + return partition_id + +# There are two kinds of def in fx.graph +# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. +# e.g. submod1 = call_module(...) +# temporary_val = submod1[0] +# submod2 = call_module(temporary_val, ...) +# 2. direct_use & direct_def, which means the output is used by next partition directly. +# e.g. submod1 = call_module(...) +# submod2 = call_module(submod1, ...) 
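+#
+# Partition ids used throughout this file follow partition_name_to_id above
+# (illustrative mapping, not an API guarantee):
+#   model input   -> 0
+#   model output  -> 1
+#   'submod_k'    -> k + 2   (e.g. partition_name_to_id('submod_3') == 5)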
+def find_input_in_partition(node, partitions, input_partitions=None): + p_input_val = None + direct_def = not node.name.startswith('getitem') + # search in input + if direct_def and input_partitions is not None: + partition_id = partition_name_to_id('', is_input=True) + for i, input_node in enumerate(input_partitions): + if input_node == node: + p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) + return p_input_val + # search submod in mid part + if direct_def: + for partition in partitions: + if partition == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=0) + return p_input_val + # search temporary value in graph + else: + for partition in partitions: + for offset, mid_val in enumerate(partition.users): + if mid_val == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) + return p_input_val + + return p_input_val + +def find_output_in_partition(node, partitions, output_partitions=None): + p_output_val = PartitionOutputVal() + for user in node.users: + direct_use = not user.name.startswith('getitem') + # user is mid partition + for partition in partitions: + # direct call + if direct_use: + if user == partition: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + # getitem call + else: + if user in partition.args: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == user: + p_output_val.add(partition_id=partition_id, offset=i) + break + + # user is output + if output_partitions is not None: + output_node = output_partitions[0] + if user.op == output_node.op: + output_keys = {} + partition_id = partition_name_to_id('', is_output=True) + torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) + for i, arg in enumerate(output_keys): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + return p_output_val + +def get_topology(gm: GraphModule): + topo = Topo() + topo_output_partition = Partition() + + input_partitions = [] + partitions = [] + output_partitions = [] + for node in gm.graph.nodes: + if node.op == 'placeholder': + input_partitions.append(node) + elif node.name.startswith('submod_'): + partitions.append(node) + elif node.op == 'output': + output_partitions.append(node) + else: + continue + + # set output for input_partition + topo_input_partition = Partition() + for partition in input_partitions: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_input_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=0, partition=topo_input_partition) + topo.set_input_partition(partition_id=0) + + for i, partition in enumerate(partitions): + topo_mid_partition = Partition() + # set input for submodule + for arg in partition.args: + cur_node = arg + p_input_val = find_input_in_partition(cur_node, partitions, input_partitions) + topo_mid_partition.add_input_val(p_input_val) + # set output for submodule + direct_use = True + for user in partition.users: + if user.name.startswith('getitem'): + direct_use = False + break + if direct_use: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + else: + 
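+            # outputs reach the next partition through getitem nodes here, so each
+            # getitem user stands for one output slot and gets its own PartitionOutputVal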
for user in partition.users: + cur_node = user + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) + + # set input for output_partition + for partition in output_partitions: + topo_output_partition = Partition() + torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( + find_input_in_partition(n, partitions, input_partitions))) + topo.set_partitions(partition_id=1, partition=topo_output_partition) + topo.set_output_partition(partition_id=1) + + return topo \ No newline at end of file diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/pipeline/middleware/topo.py new file mode 100644 index 000000000..e9d97b0b7 --- /dev/null +++ b/colossalai/pipeline/middleware/topo.py @@ -0,0 +1,164 @@ +from typing import Dict, List +from dataclasses import dataclass + +# This file includes data structure used by Pipeline Middleware. + +@dataclass +class ValPosition: + partition_id: int + offset: int + + def __str__(self) -> str: + res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionInputVal(object): + def __init__(self, partition_id, offset) -> None: + # every input from which partition_id and which offset + val_pos = ValPosition(partition_id, offset) + self._from_partition_and_offset: ValPosition = val_pos + + def get(self): + return self._from_partition_and_offset + + def __str__(self) -> str: + res = '' + res += f'<-({self._from_partition_and_offset})' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionOutputVal(object): + def __init__(self) -> None: + # every output to which partition_id and which offset + self._to_partition_and_offset: List[ValPosition] = [] + + def add(self, partition_id, offset): + val_pos = ValPosition(partition_id, offset) + self._to_partition_and_offset.append(val_pos) + + def get(self): + return self._to_partition_and_offset + + def __str__(self) -> str: + res = '' + res += '->(' + for val_pos in self._to_partition_and_offset: + res += f'{val_pos},' + res += ')' + return res + + def __repr__(self) -> str: + return self.__str__() + +class Partition(object): + def __init__(self) -> None: + self._input_vals: List[PartitionInputVal] = [] + self._output_vals: List[PartitionOutputVal] = [] + + def add_input_val(self, input_val: PartitionInputVal): + self._input_vals.append(input_val) + + def add_output_val(self, output_val: PartitionOutputVal): + self._output_vals.append(output_val) + + def get_input_vals(self): + return self._input_vals + + def get_output_vals(self): + return self._output_vals + + def __str__(self) -> str: + res = '' + res += f' input:\n' + res += f' length:{len(self._input_vals)}\n' + for i, input_val in enumerate(self._input_vals): + res += f' offset={i}:{input_val}\n' + + res += f' output:\n' + res += f' length:{len(self._output_vals)}\n' + for i, output_val in enumerate(self._output_vals): + res += f' offset={i}:{output_val}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + +# This class is a middleware between partition splitter +# and Pipeline Scheduler. It records the graph info about +# partition input/output and provides it to scheduler. +# There are three kinds of partition in Pipeline Middleware Design +# which represents the whole process of a model execution: input-fwd-output +# 1. 
input_partition: records the input of a model. +# 2. mid_partition: record the splitted forwards execution of a model. +# 3. output_partition: records the output of a model. +# attributes: +# _partitions: include all partitions +# _input_partition_id: the key represents input_partition +# _output_partition_id: the key represents output_partition +class Topo(object): + def __init__(self, input_partition_id=None, output_partition_id=None) -> None: + self._partitions: Dict[int, Partition] = {} + self._input_partition_id = input_partition_id + self._output_partition_id = output_partition_id + + def set_input_partition(self, partition_id: int): + self._input_partition_id = partition_id + + def set_output_partition(self, partition_id: int): + self._output_partition_id = partition_id + + def set_partitions(self, partition_id: int, partition: Partition): + self._partitions[partition_id] = partition + + def get_mid_partitions(self): + res = {} #{partition_id: Partition} + for partition_id, partition in self._partitions.items(): + if self._input_partition_id == partition_id or self._output_partition_id == partition_id: + continue + res[partition_id] = partition + return res + + def get_input_partition(self): + if self._input_partition_id is not None: + return self._partitions[self._input_partition_id] + return None + + def get_output_partition(self): + if self._output_partition_id is not None: + return self._partitions[self._output_partition_id] + return None + + def __str__(self) -> str: + res = '' + if len(self._partitions) == 0: + return 'Empty Topo Graph.' + + input_part = self.get_input_partition() + if input_part is not None: + res += '{\n' + res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' + res += '}\n' + + mid_parts = self.get_mid_partitions() + for i, (partition_id, part) in enumerate(mid_parts.items()): + res += '{\n' + res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' + res += '}\n' + + output_part = self.get_output_partition() + if output_part is not None: + res += '{\n' + res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' + res += '}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_DAG/dag_utils.py b/tests/test_fx/test_pipeline/test_DAG/dag_utils.py deleted file mode 100644 index 104296fb1..000000000 --- a/tests/test_fx/test_pipeline/test_DAG/dag_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer -import random -import numpy as np - -MANUAL_SEED = 0 -random.seed(MANUAL_SEED) -np.random.seed(MANUAL_SEED) -torch.manual_seed(MANUAL_SEED) - -def split_model_and_get_DAG(model, data_gen): - model.eval() - - # generate input sample - kwargs = data_gen() - - # get origin output and rng state - cpu_rng_state = torch.get_rng_state() - output = model(**kwargs) - - # tracing model - tracer = ColoTracer() - try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} - graph = tracer.trace(root=model, meta_args=meta_args) - except Exception as e: - raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() - - # apply transform passes - annotated_model = balanced_split_pass(gm, 2) - top_module, split_submodules = 
split_with_split_nodes_pass(annotated_model) - - return top_module, split_submodules[0]._DAG - -def check_input(input, input_node, top_module): - for user in input_node.users.keys(): - partition_name = user.name - assert partition_name in input['output'] - -def check_submod(submod_partition, node, top_module): - for arg in node.args: - input_part_name = None - if arg.op == 'placeholder': - input_part_name = 'MODEL_INPUT' - elif not arg.name.startswith('getitem'): - input_part_name = arg.name - else: - input_part_name = arg.args[0].name - assert input_part_name in submod_partition['input'] - - for user in node.users: - output_part_names = [] - if user.op == 'output': - output_part_names.append('MODEL_OUTPUT') - elif not user.name.startswith('getitem'): - output_part_names.append(user.name) - else: - for n in user.users: - if n.op == 'output': - output_part_names.append('MODEL_OUTPUT') - else: - output_part_names.append(n.name) - - for output_part_name in output_part_names: - assert output_part_name in submod_partition['output'] - -def check_DAG(top_module, DAG): - assert 'input_partition' in DAG - input_partition = DAG['input_partition'] - - for node in top_module.graph.nodes: - # check input - if node.op == 'placeholder': - assert node.name in input_partition - input = input_partition[node.name] - check_input(input, node, top_module) - elif node.op == 'call_module': - assert node.name in DAG - submod_partition = DAG[node.name] - check_submod(submod_partition, node, top_module) - \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_DAG/test_dag.py b/tests/test_fx/test_pipeline/test_DAG/test_dag.py deleted file mode 100644 index 7f7caa36e..000000000 --- a/tests/test_fx/test_pipeline/test_DAG/test_dag.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -import torch -import transformers -from dag_utils import split_model_and_get_DAG, check_DAG - -BATCH_SIZE = 1 -SEQ_LENGHT = 16 - - -@pytest.mark.skip('balance split v2 is not ready') -def test_opt(): - MODEL_LIST = [ - transformers.OPTModel, - #transformers.OPTForCausalLM, - ] - - config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4) - - def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) - return kwargs - - for model_cls in MODEL_LIST: - model = model_cls(config=config) - top_mod, DAG = split_model_and_get_DAG(model, data_gen) - check_DAG(top_mod, DAG) - -if __name__ == '__main__': - test_opt() \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py new file mode 100644 index 000000000..75c748705 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -0,0 +1,43 @@ +import pytest +import torch +import transformers +from topo_utils import split_model_and_get_DAG, check_topo, MLP + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + +def test_opt(): + MODEL_LIST = [ + MLP, + transformers.OPTModel, + ] + + CONFIGS = [ + {'dim': 10, 'layers': 12}, + transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), + ] + + def data_gen_MLP(): + x = torch.zeros((16, 10)) + kwargs = dict(x=x) + return kwargs + + def data_gen_OPT(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + 
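+        # zero-filled dummy tensors are sufficient: tracing below only uses their shapes and dtypes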
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + DATAGEN = [ + data_gen_MLP, + data_gen_OPT, + ] + + for i, model_cls in enumerate(MODEL_LIST): + model = model_cls(config=CONFIGS[i]) + top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i]) + # print(f'{top_mod=}\n----\n{topo=}') + check_topo(top_mod, topo) + +if __name__ == '__main__': + test_opt() \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py new file mode 100644 index 000000000..55dd65201 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -0,0 +1,92 @@ +import torch +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.middleware.adaptor import get_fx_topology +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + +class MLP(torch.nn.Module): + def __init__(self, config={}): + super().__init__() + dim = config['dim'] + layers = config['layers'] + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +def split_model_and_get_DAG(model, data_gen): + model.eval() + + # generate input sample + kwargs = data_gen() + + # tracing model + tracer = ColoTracer() + try: + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model) + + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + + return top_module, split_submodules[0]._topo + +def check_input(top_module, input_partition: Partition): + partition_output = input_partition.get_output_vals() + arg_pos = 0 + for node in top_module.graph.nodes: + if node.op == 'placeholder': + cur_checkee = partition_output[arg_pos] + to_partition_and_offset = cur_checkee.get() + assert len(to_partition_and_offset) == len(node.users.keys()) + arg_pos += 1 + + assert arg_pos == len(partition_output) + +def check_submod(top_module, part_id, mid_partition: Partition): + partition_input = mid_partition.get_input_vals() + partition_output = mid_partition.get_output_vals() + + cnt = 1 + cur_node = None + for node in top_module.graph.nodes: + if node.name.startswith('submod'): + cnt += 1 + if cnt == part_id: + cur_node = node + break + + assert len(partition_input) == len(cur_node.args) + assert len(partition_output) == len(cur_node.users) + +def check_topo(top_module, topo: Topo): + input_partition = topo.get_input_partition() + mid_partitions = topo.get_mid_partitions() + + check_input(top_module, input_partition) + for part_id, submod in mid_partitions.items(): + check_submod(top_module, part_id, submod) + \ No newline at end of file
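For reference, a minimal usage sketch of the Topo data structure introduced by this patch, assuming only the API defined in colossalai/pipeline/middleware/topo.py above. The three-partition layout (model input -> submod_0 -> model output) is invented for illustration and follows the id convention input=0, output=1, submod_k=k+2; a real Topo would instead be produced by get_fx_topology on the split GraphModule.

from colossalai.pipeline.middleware import Topo, Partition, PartitionInputVal, PartitionOutputVal

# input partition (id 0): one placeholder forwarded to submod_0's arg 0
input_part = Partition()
out_val = PartitionOutputVal()
out_val.add(partition_id=2, offset=0)
input_part.add_output_val(out_val)

# mid partition submod_0 (id 2): reads input 0 of the model, sends its output 0 to the model output
mid_part = Partition()
mid_part.add_input_val(PartitionInputVal(partition_id=0, offset=0))
mid_out = PartitionOutputVal()
mid_out.add(partition_id=1, offset=0)
mid_part.add_output_val(mid_out)

# output partition (id 1): receives its only value from submod_0
output_part = Partition()
output_part.add_input_val(PartitionInputVal(partition_id=2, offset=0))

topo = Topo()
topo.set_partitions(partition_id=0, partition=input_part)
topo.set_input_partition(partition_id=0)
topo.set_partitions(partition_id=2, partition=mid_part)
topo.set_partitions(partition_id=1, partition=output_part)
topo.set_output_partition(partition_id=1)

# a scheduler-side consumer can now ask, for each mid partition, where every input comes from
for pid, part in topo.get_mid_partitions().items():
    for offset, in_val in enumerate(part.get_input_vals()):
        src = in_val.get()    # ValPosition(partition_id, offset)
        print(f'partition {pid} input {offset} <- partition {src.partition_id}, output {src.offset}')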