|
|
|
@ -72,25 +72,53 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
|
|
|
|
new_placeholder_list = [] |
|
|
|
|
for node in gm.graph.nodes: |
|
|
|
|
if node.op == 'output': |
|
|
|
|
output_type = node.args[0].__class__ |
|
|
|
|
output_args.extend(list(node.args[0])) |
|
|
|
|
for n in node.args[0]: |
|
|
|
|
if next_partition_placeholders and n not in next_partition_placeholders: |
|
|
|
|
output_args.remove(n) |
|
|
|
|
if isinstance(node.args[0], (tuple, list)): |
|
|
|
|
output_type = node.args[0].__class__ |
|
|
|
|
output_args.extend([n.name for n in node.args[0]]) |
|
|
|
|
else: |
|
|
|
|
output_args.append(node.args[0].name) |
|
|
|
|
rm_list = [] |
|
|
|
|
for name in output_args: |
|
|
|
|
if next_partition_placeholders and name not in next_partition_placeholders: |
|
|
|
|
rm_list.append(name) |
|
|
|
|
for name in rm_list: |
|
|
|
|
output_args.remove(name) |
|
|
|
|
gm.graph.erase_node(node) |
|
|
|
|
else: |
|
|
|
|
non_output_list.append(node.name) |
|
|
|
|
for node in next_partition_placeholders: |
|
|
|
|
if node not in output_args: |
|
|
|
|
output_args.append(node) |
|
|
|
|
for node in output_args: |
|
|
|
|
if node.name not in non_output_list: |
|
|
|
|
gm.graph.placeholder(node.name) |
|
|
|
|
|
|
|
|
|
for name in next_partition_placeholders: |
|
|
|
|
if name not in output_args: |
|
|
|
|
output_args.append(name) |
|
|
|
|
|
|
|
|
|
for name in output_args: |
|
|
|
|
if name not in non_output_list: |
|
|
|
|
gm.graph.placeholder(name) |
|
|
|
|
|
|
|
|
|
# convert name to node for output_args |
|
|
|
|
for index, name in enumerate(output_args): |
|
|
|
|
for n in gm.graph.nodes: |
|
|
|
|
if n.name == name: |
|
|
|
|
output_args[index] = n |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# reorder the output args to make sure |
|
|
|
|
# output args has same order as next partition placeholder |
|
|
|
|
reorder_output_args = [] |
|
|
|
|
if next_partition_placeholders: |
|
|
|
|
for name in next_partition_placeholders: |
|
|
|
|
for node in output_args: |
|
|
|
|
if node.name == name: |
|
|
|
|
reorder_output_args.append(node) |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
for node in gm.graph.nodes: |
|
|
|
|
if node.op == 'placeholder': |
|
|
|
|
new_placeholder_list.append(node) |
|
|
|
|
gm.graph.output(output_type(output_args)) |
|
|
|
|
new_placeholder_list.append(node.name) |
|
|
|
|
if output_type is not None: |
|
|
|
|
gm.graph.output(output_type(output_args)) |
|
|
|
|
else: |
|
|
|
|
gm.graph.output(output_args) |
|
|
|
|
gm.recompile() |
|
|
|
|
return gm, new_placeholder_list |
|
|
|
|
|
|
|
|
@ -115,15 +143,6 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
|
|
|
|
for submodule in submodules: |
|
|
|
|
submodule = eliminate_unused_placeholders(submodule) |
|
|
|
|
placeholder_dict[submodule] = [] |
|
|
|
|
for node in submodule.graph.nodes: |
|
|
|
|
if node.op == 'placeholder': |
|
|
|
|
placeholder_dict[submodule].append(node) |
|
|
|
|
output_dict = {} |
|
|
|
|
for submodule in submodules: |
|
|
|
|
output_dict[submodule] = [] |
|
|
|
|
for node in submodule.graph.nodes: |
|
|
|
|
if node.op == 'output': |
|
|
|
|
output_dict[submodule].append(node.name) |
|
|
|
|
submodules.reverse() |
|
|
|
|
for index, submodule in enumerate(submodules): |
|
|
|
|
if index == 0: |
|
|
|
@ -297,7 +316,7 @@ def split_module_for_gpt2_test(
|
|
|
|
|
name=node.name) |
|
|
|
|
new_node.meta = node.meta.copy() |
|
|
|
|
partition.environment[node] = new_node |
|
|
|
|
assert 'add_85' in orig_nodes |
|
|
|
|
|
|
|
|
|
# Set up values to construct base module |
|
|
|
|
base_mod_env: Dict[str, torch.fx.node.Node] = {} |
|
|
|
|
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() |
|
|
|
|