mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix some bugs during gpt2 testing (#1379)
parent
828b9e5e0d
commit
df54481473
|
@ -72,25 +72,53 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||||
new_placeholder_list = []
|
new_placeholder_list = []
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.op == 'output':
|
if node.op == 'output':
|
||||||
output_type = node.args[0].__class__
|
if isinstance(node.args[0], (tuple, list)):
|
||||||
output_args.extend(list(node.args[0]))
|
output_type = node.args[0].__class__
|
||||||
for n in node.args[0]:
|
output_args.extend([n.name for n in node.args[0]])
|
||||||
if next_partition_placeholders and n not in next_partition_placeholders:
|
else:
|
||||||
output_args.remove(n)
|
output_args.append(node.args[0].name)
|
||||||
|
rm_list = []
|
||||||
|
for name in output_args:
|
||||||
|
if next_partition_placeholders and name not in next_partition_placeholders:
|
||||||
|
rm_list.append(name)
|
||||||
|
for name in rm_list:
|
||||||
|
output_args.remove(name)
|
||||||
gm.graph.erase_node(node)
|
gm.graph.erase_node(node)
|
||||||
else:
|
else:
|
||||||
non_output_list.append(node.name)
|
non_output_list.append(node.name)
|
||||||
for node in next_partition_placeholders:
|
|
||||||
if node not in output_args:
|
for name in next_partition_placeholders:
|
||||||
output_args.append(node)
|
if name not in output_args:
|
||||||
for node in output_args:
|
output_args.append(name)
|
||||||
if node.name not in non_output_list:
|
|
||||||
gm.graph.placeholder(node.name)
|
for name in output_args:
|
||||||
|
if name not in non_output_list:
|
||||||
|
gm.graph.placeholder(name)
|
||||||
|
|
||||||
|
# convert name to node for output_args
|
||||||
|
for index, name in enumerate(output_args):
|
||||||
|
for n in gm.graph.nodes:
|
||||||
|
if n.name == name:
|
||||||
|
output_args[index] = n
|
||||||
|
continue
|
||||||
|
|
||||||
|
# reorder the output args to make sure
|
||||||
|
# output args has same order as next partition placeholder
|
||||||
|
reorder_output_args = []
|
||||||
|
if next_partition_placeholders:
|
||||||
|
for name in next_partition_placeholders:
|
||||||
|
for node in output_args:
|
||||||
|
if node.name == name:
|
||||||
|
reorder_output_args.append(node)
|
||||||
|
continue
|
||||||
|
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.op == 'placeholder':
|
if node.op == 'placeholder':
|
||||||
new_placeholder_list.append(node)
|
new_placeholder_list.append(node.name)
|
||||||
gm.graph.output(output_type(output_args))
|
if output_type is not None:
|
||||||
|
gm.graph.output(output_type(output_args))
|
||||||
|
else:
|
||||||
|
gm.graph.output(output_args)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm, new_placeholder_list
|
return gm, new_placeholder_list
|
||||||
|
|
||||||
|
@ -115,15 +143,6 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||||
for submodule in submodules:
|
for submodule in submodules:
|
||||||
submodule = eliminate_unused_placeholders(submodule)
|
submodule = eliminate_unused_placeholders(submodule)
|
||||||
placeholder_dict[submodule] = []
|
placeholder_dict[submodule] = []
|
||||||
for node in submodule.graph.nodes:
|
|
||||||
if node.op == 'placeholder':
|
|
||||||
placeholder_dict[submodule].append(node)
|
|
||||||
output_dict = {}
|
|
||||||
for submodule in submodules:
|
|
||||||
output_dict[submodule] = []
|
|
||||||
for node in submodule.graph.nodes:
|
|
||||||
if node.op == 'output':
|
|
||||||
output_dict[submodule].append(node.name)
|
|
||||||
submodules.reverse()
|
submodules.reverse()
|
||||||
for index, submodule in enumerate(submodules):
|
for index, submodule in enumerate(submodules):
|
||||||
if index == 0:
|
if index == 0:
|
||||||
|
@ -297,7 +316,7 @@ def split_module_for_gpt2_test(
|
||||||
name=node.name)
|
name=node.name)
|
||||||
new_node.meta = node.meta.copy()
|
new_node.meta = node.meta.copy()
|
||||||
partition.environment[node] = new_node
|
partition.environment[node] = new_node
|
||||||
assert 'add_85' in orig_nodes
|
|
||||||
# Set up values to construct base module
|
# Set up values to construct base module
|
||||||
base_mod_env: Dict[str, torch.fx.node.Node] = {}
|
base_mod_env: Dict[str, torch.fx.node.Node] = {}
|
||||||
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||||
|
|
Loading…
Reference in New Issue