Browse Source

[hotfix] fix some bugs during gpt2 testing (#1379)

pull/1384/head
YuliangLiu0306 2 years ago committed by GitHub
parent
commit
df54481473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 59
      colossalai/fx/passes/passes_for_gpt2_test.py

59
colossalai/fx/passes/passes_for_gpt2_test.py

@ -72,25 +72,53 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
new_placeholder_list = []
for node in gm.graph.nodes:
if node.op == 'output':
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend(list(node.args[0]))
for n in node.args[0]:
if next_partition_placeholders and n not in next_partition_placeholders:
output_args.remove(n)
output_args.extend([n.name for n in node.args[0]])
else:
output_args.append(node.args[0].name)
rm_list = []
for name in output_args:
if next_partition_placeholders and name not in next_partition_placeholders:
rm_list.append(name)
for name in rm_list:
output_args.remove(name)
gm.graph.erase_node(node)
else:
non_output_list.append(node.name)
for node in next_partition_placeholders:
if node not in output_args:
output_args.append(node)
for name in next_partition_placeholders:
if name not in output_args:
output_args.append(name)
for name in output_args:
if name not in non_output_list:
gm.graph.placeholder(name)
# convert name to node for output_args
for index, name in enumerate(output_args):
for n in gm.graph.nodes:
if n.name == name:
output_args[index] = n
continue
# reorder the output args to make sure
# output args has same order as next partition placeholder
reorder_output_args = []
if next_partition_placeholders:
for name in next_partition_placeholders:
for node in output_args:
if node.name not in non_output_list:
gm.graph.placeholder(node.name)
if node.name == name:
reorder_output_args.append(node)
continue
for node in gm.graph.nodes:
if node.op == 'placeholder':
new_placeholder_list.append(node)
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
else:
gm.graph.output(output_args)
gm.recompile()
return gm, new_placeholder_list
@ -115,15 +143,6 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
for submodule in submodules:
submodule = eliminate_unused_placeholders(submodule)
placeholder_dict[submodule] = []
for node in submodule.graph.nodes:
if node.op == 'placeholder':
placeholder_dict[submodule].append(node)
output_dict = {}
for submodule in submodules:
output_dict[submodule] = []
for node in submodule.graph.nodes:
if node.op == 'output':
output_dict[submodule].append(node.name)
submodules.reverse()
for index, submodule in enumerate(submodules):
if index == 0:
@ -297,7 +316,7 @@ def split_module_for_gpt2_test(
name=node.name)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
assert 'add_85' in orig_nodes
# Set up values to construct base module
base_mod_env: Dict[str, torch.fx.node.Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()

Loading…
Cancel
Save