diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py
index 93aa7fb99..f98fcd686 100644
--- a/colossalai/fx/passes/passes_for_gpt2_test.py
+++ b/colossalai/fx/passes/passes_for_gpt2_test.py
@@ -72,25 +72,53 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
         new_placeholder_list = []
         for node in gm.graph.nodes:
             if node.op == 'output':
-                output_type = node.args[0].__class__
-                output_args.extend(list(node.args[0]))
-                for n in node.args[0]:
-                    if next_partition_placeholders and n not in next_partition_placeholders:
-                        output_args.remove(n)
+                if isinstance(node.args[0], (tuple, list)):
+                    output_type = node.args[0].__class__
+                    output_args.extend([n.name for n in node.args[0]])
+                else:
+                    output_args.append(node.args[0].name)
+                rm_list = []
+                for name in output_args:
+                    if next_partition_placeholders and name not in next_partition_placeholders:
+                        rm_list.append(name)
+                for name in rm_list:
+                    output_args.remove(name)
                 gm.graph.erase_node(node)
             else:
                 non_output_list.append(node.name)
-        for node in next_partition_placeholders:
-            if node not in output_args:
-                output_args.append(node)
-        for node in output_args:
-            if node.name not in non_output_list:
-                gm.graph.placeholder(node.name)
+
+        for name in next_partition_placeholders:
+            if name not in output_args:
+                output_args.append(name)
+
+        for name in output_args:
+            if name not in non_output_list:
+                gm.graph.placeholder(name)
+
+        # convert name to node for output_args
+        for index, name in enumerate(output_args):
+            for n in gm.graph.nodes:
+                if n.name == name:
+                    output_args[index] = n
+                    continue
+
+        # reorder the output args to make sure
+        # output args has same order as next partition placeholder
+        reorder_output_args = []
+        if next_partition_placeholders:
+            for name in next_partition_placeholders:
+                for node in output_args:
+                    if node.name == name:
+                        reorder_output_args.append(node)
+                        continue
 
         for node in gm.graph.nodes:
             if node.op == 'placeholder':
-                new_placeholder_list.append(node)
-        gm.graph.output(output_type(output_args))
+                new_placeholder_list.append(node.name)
+        if output_type is not None:
+            gm.graph.output(output_type(output_args))
+        else:
+            gm.graph.output(output_args)
         gm.recompile()
         return gm, new_placeholder_list
 
@@ -115,15 +143,6 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
     for submodule in submodules:
         submodule = eliminate_unused_placeholders(submodule)
         placeholder_dict[submodule] = []
-        for node in submodule.graph.nodes:
-            if node.op == 'placeholder':
-                placeholder_dict[submodule].append(node)
-    output_dict = {}
-    for submodule in submodules:
-        output_dict[submodule] = []
-        for node in submodule.graph.nodes:
-            if node.op == 'output':
-                output_dict[submodule].append(node.name)
     submodules.reverse()
     for index, submodule in enumerate(submodules):
         if index == 0:
@@ -297,7 +316,7 @@ def split_module_for_gpt2_test(
                                                    name=node.name)
             new_node.meta = node.meta.copy()
             partition.environment[node] = new_node
-    assert 'add_85' in orig_nodes
+
     # Set up values to construct base module
     base_mod_env: Dict[str, torch.fx.node.Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
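
The first hunk switches the refill logic from node identity to node-name matching: outputs the next partition never consumes are dropped, placeholders the next partition still needs are appended, and the surviving names are resolved back to nodes in the order of the next partition's placeholders. Below is a minimal sketch of that idea, outside the patch; the Node class, the refill_and_reorder helper, and the sample names are hypothetical stand-ins for torch.fx graph nodes, and a dict lookup stands in for the patch's nested scans.

# Hypothetical sketch of the name-based refill-and-reorder idea in the hunk
# above; Node and refill_and_reorder are stand-ins, not torch.fx APIs.
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    name: str


def refill_and_reorder(output_names: List[str],
                       next_placeholders: List[str],
                       graph_nodes: List[Node]) -> List[Node]:
    # Drop outputs the next partition never consumes.
    kept = [n for n in output_names
            if not next_placeholders or n in next_placeholders]
    # Add inputs the next partition needs but this one does not yet emit.
    for name in next_placeholders:
        if name not in kept:
            kept.append(name)
    # Resolve names back to nodes, following the next partition's order.
    by_name = {node.name: node for node in graph_nodes}
    order = next_placeholders if next_placeholders else kept
    return [by_name[name] for name in order if name in by_name]


nodes = [Node('x'), Node('y'), Node('z')]
print([n.name for n in refill_and_reorder(['z', 'x'], ['x', 'z'], nodes)])
# ['x', 'z'] -- outputs now match the next partition's placeholder order

Matching by name rather than by node object is what lets the pass line up the outputs of partition i with the placeholders of partition i + 1 even though the two submodules hold distinct Node instances for the same value.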