mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix some bugs during gpt2 testing (#1379)
parent
828b9e5e0d
commit
df54481473
|
@ -72,25 +72,53 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
|||
new_placeholder_list = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'output':
|
||||
output_type = node.args[0].__class__
|
||||
output_args.extend(list(node.args[0]))
|
||||
for n in node.args[0]:
|
||||
if next_partition_placeholders and n not in next_partition_placeholders:
|
||||
output_args.remove(n)
|
||||
if isinstance(node.args[0], (tuple, list)):
|
||||
output_type = node.args[0].__class__
|
||||
output_args.extend([n.name for n in node.args[0]])
|
||||
else:
|
||||
output_args.append(node.args[0].name)
|
||||
rm_list = []
|
||||
for name in output_args:
|
||||
if next_partition_placeholders and name not in next_partition_placeholders:
|
||||
rm_list.append(name)
|
||||
for name in rm_list:
|
||||
output_args.remove(name)
|
||||
gm.graph.erase_node(node)
|
||||
else:
|
||||
non_output_list.append(node.name)
|
||||
for node in next_partition_placeholders:
|
||||
if node not in output_args:
|
||||
output_args.append(node)
|
||||
for node in output_args:
|
||||
if node.name not in non_output_list:
|
||||
gm.graph.placeholder(node.name)
|
||||
|
||||
for name in next_partition_placeholders:
|
||||
if name not in output_args:
|
||||
output_args.append(name)
|
||||
|
||||
for name in output_args:
|
||||
if name not in non_output_list:
|
||||
gm.graph.placeholder(name)
|
||||
|
||||
# convert name to node for output_args
|
||||
for index, name in enumerate(output_args):
|
||||
for n in gm.graph.nodes:
|
||||
if n.name == name:
|
||||
output_args[index] = n
|
||||
continue
|
||||
|
||||
# reorder the output args to make sure
|
||||
# output args has same order as next partition placeholder
|
||||
reorder_output_args = []
|
||||
if next_partition_placeholders:
|
||||
for name in next_partition_placeholders:
|
||||
for node in output_args:
|
||||
if node.name == name:
|
||||
reorder_output_args.append(node)
|
||||
continue
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
new_placeholder_list.append(node)
|
||||
gm.graph.output(output_type(output_args))
|
||||
new_placeholder_list.append(node.name)
|
||||
if output_type is not None:
|
||||
gm.graph.output(output_type(output_args))
|
||||
else:
|
||||
gm.graph.output(output_args)
|
||||
gm.recompile()
|
||||
return gm, new_placeholder_list
|
||||
|
||||
|
@ -115,15 +143,6 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
|||
for submodule in submodules:
|
||||
submodule = eliminate_unused_placeholders(submodule)
|
||||
placeholder_dict[submodule] = []
|
||||
for node in submodule.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
placeholder_dict[submodule].append(node)
|
||||
output_dict = {}
|
||||
for submodule in submodules:
|
||||
output_dict[submodule] = []
|
||||
for node in submodule.graph.nodes:
|
||||
if node.op == 'output':
|
||||
output_dict[submodule].append(node.name)
|
||||
submodules.reverse()
|
||||
for index, submodule in enumerate(submodules):
|
||||
if index == 0:
|
||||
|
@ -297,7 +316,7 @@ def split_module_for_gpt2_test(
|
|||
name=node.name)
|
||||
new_node.meta = node.meta.copy()
|
||||
partition.environment[node] = new_node
|
||||
assert 'add_85' in orig_nodes
|
||||
|
||||
# Set up values to construct base module
|
||||
base_mod_env: Dict[str, torch.fx.node.Node] = {}
|
||||
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||
|
|
Loading…
Reference in New Issue