[fx] update split module pass and add customized policy (#1373)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] update split module pass and add customized policy
YuliangLiu0306 2022-07-27 13:40:54 +08:00 committed by GitHub
parent be229217ce
commit 52bc2dc271
2 changed files with 85 additions and 20 deletions


@@ -61,6 +61,8 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
    for node in mod_graph.nodes:
        if pp_size <= 1:
            break
        if 'pipe_split' in node.name:
            continue
        accumulate_node_size += node.node_size
        if accumulate_node_size >= partition_size:
            accumulate_node_size = 0
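
For intuition, here is a minimal, self-contained sketch (toy node names and sizes; not the ColossalAI implementation) of the accumulation loop above, showing how nodes whose name contains `pipe_split` are ignored when the running partition size is computed:

```python
# Toy stand-in for gm.graph.nodes: (name, node_size) pairs; the values are made up.
nodes = [("embed", 2), ("block_0", 4), ("pipe_split", 0), ("block_1", 4),
         ("block_2", 4), ("pipe_split_1", 0), ("block_3", 4), ("head", 2)]

partition_size = 8
accumulate_node_size = 0
split_after = []
for name, node_size in nodes:
    if 'pipe_split' in name:
        # Existing split markers carry no computation; skip them so they never
        # influence where the next split point lands.
        continue
    accumulate_node_size += node_size
    if accumulate_node_size >= partition_size:
        accumulate_node_size = 0
        split_after.append(name)

print(split_after)  # ['block_1', 'block_3'] -> where new pipe_split markers would be inserted
```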


@@ -5,11 +5,45 @@ from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.meta_info_prop import TensorMetadata
import inspect
from typing import List
from colossalai.fx.passes.split_module import Partition
from colossalai.fx.passes.adding_split_node_pass import pipe_split
from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass
from torch.fx.node import Node

def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
    '''
    This pass is only used for the GPT-2 performance test. It may be moved into adding_split_node_pass.py
    and will be deprecated in the future.
    '''
    mod_graph = gm.graph
    valid_children_size = 0
    valid_children = []
    for node in mod_graph.nodes:
        if node.op == "call_module":
            valid_children_size += 1
            valid_children.append(node.target)
    if valid_children_size < pp_size:
        # If there are not enough valid children to shard, fall back to the balanced policy instead of the uniform policy.
        return balanced_split_pass(gm, pp_size)
    accumulate_layer_amount = 0
    list_of_part = partition_list
    part_index = 0
    for node in mod_graph.nodes:
        if pp_size <= 1:
            break
        if node.op == "call_module":
            if node.target in valid_children:
                accumulate_layer_amount += 1
        if accumulate_layer_amount == list_of_part[part_index]:
            part_index += 1
            pp_size -= 1
            with mod_graph.inserting_after(node):
                split_node = mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm
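
A hedged usage sketch of the customized policy above. The toy model and the `pp_size`/`partition_list` values are illustrative only, and calling `customized_split_pass_for_gpt2` directly assumes this file's module is importable (the diff does not show its path):

```python
import torch
import torch.fx

class ToyBlocks(torch.nn.Module):
    """Stand-in for GPT-2: a stack of call_module nodes to split across stages."""

    def __init__(self, num_layers: int = 8):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(num_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

gm = torch.fx.symbolic_trace(ToyBlocks())
# Insert pipe_split markers after the 2nd, 4th and 6th call_module nodes,
# yielding 4 pipeline stages for pp_size=4.
gm = customized_split_pass_for_gpt2(gm, pp_size=4, partition_list=[2, 4, 6, 8])
```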

def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
    '''
    This pass will be used in gpt2 test, only a part of changes may be added into
@@ -25,21 +59,40 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
        gm.recompile()
        return gm

    def eliminate_unused_outputs(gm, next_partition_placeholders):
    def refill_outputs_and_placeholders(gm, next_partition_placeholders):
        '''
        This method is used to eliminate the outputs of the previous partition which are unused in the next partition.
        The split module pass treats partitions as a DAG, but we need to treat them as a singly linked list in pipeline parallelism.
        The difference is that if an output of partition 0 is an input argument of partition 3, the DAG will not transfer it
        to partition 1 and partition 2. However, in a singly linked list, we need to do so.
        '''
        output_type = None
        output_args = []
        non_output_list = []
        new_placeholder_list = []
        for node in gm.graph.nodes:
            if node.op == 'output':
                output_type = node.args[0].__class__
                output_args = list(node.args[0])
                output_args.extend(list(node.args[0]))
                for n in node.args[0]:
                    if n.name not in next_partition_placeholders:
                    if next_partition_placeholders and n not in next_partition_placeholders:
                        output_args.remove(n)
                gm.graph.erase_node(node)
            else:
                non_output_list.append(node.name)
        for node in next_partition_placeholders:
            if node not in output_args:
                output_args.append(node)
        for node in output_args:
            if node.name not in non_output_list:
                gm.graph.placeholder(node.name)
        for node in gm.graph.nodes:
            if node.op == 'placeholder':
                new_placeholder_list.append(node)
        gm.graph.output(output_type(output_args))
        gm.recompile()
        return gm
        return gm, new_placeholder_list

    def split_callback(n: torch.fx.Node):
        nonlocal part_idx
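
To make the docstring of `refill_outputs_and_placeholders` above concrete, a small self-contained illustration (toy value names, not ColossalAI code) of which values an intermediate pipeline stage has to pass through when partitions are chained as a list rather than treated as a DAG:

```python
# 'x0' is produced by partition 0 but only consumed by partition 3, so when the
# partitions run as a pipeline chain, partitions 1 and 2 must accept it as a
# placeholder and re-emit it as an output.
produces = {0: {'x0', 'h0'}, 1: {'h1'}, 2: {'h2'}, 3: {'out'}}
consumes = {0: set(), 1: {'h0'}, 2: {'h1'}, 3: {'h2', 'x0'}}

for idx in (1, 2):
    produced_before = set().union(*(produces[i] for i in range(idx)))
    consumed_after = set().union(*(consumes[i] for i in range(idx + 1, 4)))
    # values produced upstream and still needed downstream, but neither produced
    # nor consumed here, must be threaded through this partition
    pass_through = (produced_before & consumed_after) - produces[idx] - consumes[idx]
    print(f"partition {idx} must pass through: {sorted(pass_through)}")
# partition 1 must pass through: ['x0']
# partition 2 must pass through: ['x0']
```
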
@@ -64,13 +117,22 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
        placeholder_dict[submodule] = []
        for node in submodule.graph.nodes:
            if node.op == 'placeholder':
                placeholder_dict[submodule].append(node.name)
                placeholder_dict[submodule].append(node)
    output_dict = {}
    for submodule in submodules:
        output_dict[submodule] = []
        for node in submodule.graph.nodes:
            if node.op == 'output':
                output_dict[submodule].append(node.name)
    submodules.reverse()
    for index, submodule in enumerate(submodules):
        if index >= len(submodules) - 1:
            break
        submodule = eliminate_unused_outputs(submodule, placeholder_dict[submodules[index + 1]])
        if index == 0:
            placeholder_list = []
        else:
            placeholder_list = placeholder_dict[submodules[index - 1]]
        submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list)
        submodule.recompile()
    split_mod.recompile()
    return split_mod, split_submodules
@@ -118,7 +180,7 @@ def split_module_for_gpt2_test(
        _gen_all_ancestors_set(node)
        for n in list(all_ancestors):
            if n.op != 'placeholder':
            if n.op != 'placeholder' and n._fx_partition > partition_name:
                n._fx_partition = partition_name

    def record_cross_partition_use(def_node: torch.fx.node.Node,
@@ -126,14 +188,14 @@ def split_module_for_gpt2_test(
        def_partition_name = getattr(def_node, '_fx_partition', None)
        use_partition_name = getattr(use_node, '_fx_partition', None)
        if def_partition_name != use_partition_name:
            if 'tensor_meta' in def_node.meta:
                if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
                    _move_all_ancestors_into_partition(use_node, def_partition_name)
                    node_process_list.extend(use_node.all_input_nodes)
                    node_process_list.extend(list(use_node.users))
                    node_process_list.append(use_node)
            # if 'tensor_meta' in def_node.meta:
            #     if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
            #         _move_all_ancestors_into_partition(use_node, def_partition_name)
            #         node_process_list.extend(use_node.all_input_nodes)
            #         node_process_list.extend(list(use_node.users))
            #         node_process_list.append(use_node)
                    return
            #         return
        if def_partition_name is not None:
            def_partition = partitions[def_partition_name]
@@ -231,10 +293,11 @@ def split_module_for_gpt2_test(
        new_node = partition.graph.create_node(op=node.op,
                                               target=target,
                                               args=gathered_args,
                                               kwargs=gathered_kwargs)
                                               kwargs=gathered_kwargs,
                                               name=node.name)
        new_node.meta = node.meta.copy()
        partition.environment[node] = new_node
    assert 'add_85' in orig_nodes
    # Set up values to construct base module
    base_mod_env: Dict[str, torch.fx.node.Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()