[fx] update split module pass and add customized policy (#1373)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] update split module pass and add customized policy
Author: YuliangLiu0306 (committed by GitHub)
Date: 2022-07-27 13:40:54 +08:00
parent be229217ce
commit 52bc2dc271
2 changed files with 85 additions and 20 deletions


@@ -61,6 +61,8 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
     for node in mod_graph.nodes:
         if pp_size <= 1:
             break
+        if 'pipe_split' in node.name:
+            continue
         accumulate_node_size += node.node_size
         if accumulate_node_size >= partition_size:
             accumulate_node_size = 0
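
The new guard matters because pipe_split markers are zero-cost sentinel nodes: skipping them keeps split markers from counting toward a stage's size budget when the pass runs over an already-annotated graph. Below is a self-contained sketch of how such a sentinel lands in a graph; the real marker lives in colossalai.fx.passes.adding_split_node_pass, and the Tiny module here is illustrative only.

    import torch
    from torch.fx import symbolic_trace

    def pipe_split():
        pass    # no-op marker; the splitter cuts the graph at each occurrence

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(4, 4)
            self.fc2 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.fc2(self.fc1(x))

    gm = symbolic_trace(Tiny())
    for node in list(gm.graph.nodes):
        if node.op == 'call_module' and node.target == 'fc1':
            with gm.graph.inserting_after(node):
                split_node = gm.graph.create_node('call_function', pipe_split)
    assert 'pipe_split' in split_node.name    # what the new guard matches
    gm.recompile()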


@@ -5,11 +5,45 @@ from torch.fx._compatibility import compatibility
 from packaging import version
 from colossalai.fx.passes.meta_info_prop import TensorMetadata
 import inspect
+from typing import List
 from colossalai.fx.passes.split_module import Partition
-from colossalai.fx.passes.adding_split_node_pass import pipe_split
+from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass
 from torch.fx.node import Node
+
+
+def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
+    '''
+    This pass is only used for the GPT-2 performance test. It may be moved into
+    adding_split_node_pass.py and will be deprecated in the future.
+    '''
+    mod_graph = gm.graph
+    valid_children_size = 0
+    valid_children = []
+    for node in mod_graph.nodes:
+        if node.op == "call_module":
+            valid_children_size += 1
+            valid_children.append(node.target)
+    if valid_children_size < pp_size:
+        # If there are not enough valid children to shard, fall back to the
+        # balanced policy instead of the uniform one.
+        return balanced_split_pass(gm, pp_size)
+    accumulate_layer_amount = 0
+    list_of_part = partition_list
+    part_index = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        if node.op == "call_module":
+            if node.target in valid_children:
+                accumulate_layer_amount += 1
+            if accumulate_layer_amount == list_of_part[part_index]:
+                part_index += 1
+                pp_size -= 1
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
 def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
     '''
     This pass will be used in gpt2 test, only a part of changes may be added into
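
For context, partition_list entries are cumulative call_module counts: a pipe_split marker is inserted after the 3rd, 6th and 9th matching module call below, yielding pp_size stages. A hedged usage sketch, assuming both helpers from this diff are importable; the toy Sequential stands in for the real GPT-2, which the test traces with ColossalAI's own tracer rather than a bare symbolic_trace:

    import torch
    from torch.fx import symbolic_trace

    # Toy stand-in for GPT-2: any forward that is a chain of call_module nodes.
    model = torch.nn.Sequential(*[torch.nn.Linear(8, 8) for _ in range(12)])
    gm = symbolic_trace(model)

    # Cut into 4 stages after the 3rd, 6th and 9th module call.
    annotated_gm = customized_split_pass_for_gpt2(gm, pp_size=4, partition_list=[3, 6, 9])
    split_mod, split_submodules = split_with_split_nodes_pass_for_gp2_test(annotated_gm)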
@@ -25,21 +59,40 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
         gm.recompile()
         return gm

-    def eliminate_unused_outputs(gm, next_partition_placeholders):
+    def refill_outputs_and_placeholders(gm, next_partition_placeholders):
         '''
         This method is used to eliminate the outputs of the previous partition which are unused in the next partition.
+        The split module pass treats partitions as a DAG, but pipeline parallelism needs to treat them as a
+        singly linked list. The difference: if an output of partition 0 is an input argument of partition 3,
+        the DAG will not transfer it through partitions 1 and 2, whereas the linked list must do so.
         '''
+        output_type = None
+        output_args = []
+        non_output_list = []
+        new_placeholder_list = []
         for node in gm.graph.nodes:
             if node.op == 'output':
                 output_type = node.args[0].__class__
-                output_args = list(node.args[0])
+                output_args.extend(list(node.args[0]))
                 for n in node.args[0]:
-                    if n.name not in next_partition_placeholders:
+                    if next_partition_placeholders and n not in next_partition_placeholders:
                         output_args.remove(n)
                 gm.graph.erase_node(node)
+            else:
+                non_output_list.append(node.name)
+        for node in next_partition_placeholders:
+            if node not in output_args:
+                output_args.append(node)
+        for node in output_args:
+            if node.name not in non_output_list:
+                gm.graph.placeholder(node.name)
+        for node in gm.graph.nodes:
+            if node.op == 'placeholder':
+                new_placeholder_list.append(node)
         gm.graph.output(output_type(output_args))
         gm.recompile()
-        return gm
+        return gm, new_placeholder_list

     def split_callback(n: torch.fx.Node):
         nonlocal part_idx
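
The docstring's point can be made concrete with a small self-contained toy (names illustrative, not the pass's actual output): a value produced in stage 0 but consumed only in stage 3 must appear in the outputs of stages 0 to 2 and in the placeholders of stages 1 to 3, which is exactly the refilling this helper performs.

    def stage0(x):
        h = x + 1          # produced here, consumed only in stage3
        y = x * 2
        return y, h

    def stage1(y, h):      # h refilled as a placeholder, passed straight through
        return y + 3, h

    def stage2(y, h):
        return y * 4, h

    def stage3(y, h):
        return y - h       # h finally consumed

    out = stage3(*stage2(*stage1(*stage0(5.0))))    # 46.0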
@@ -64,13 +117,22 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
         placeholder_dict[submodule] = []
         for node in submodule.graph.nodes:
             if node.op == 'placeholder':
-                placeholder_dict[submodule].append(node.name)
+                placeholder_dict[submodule].append(node)
+    output_dict = {}
+    for submodule in submodules:
+        output_dict[submodule] = []
+        for node in submodule.graph.nodes:
+            if node.op == 'output':
+                output_dict[submodule].append(node.name)
+    submodules.reverse()
     for index, submodule in enumerate(submodules):
-        if index >= len(submodules) - 1:
-            break
-        submodule = eliminate_unused_outputs(submodule, placeholder_dict[submodules[index + 1]])
+        if index == 0:
+            placeholder_list = []
+        else:
+            placeholder_list = placeholder_dict[submodules[index - 1]]
+        submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list)
         submodule.recompile()
     split_mod.recompile()
     return split_mod, split_submodules
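
The submodules are walked in reverse because placeholder demand flows backwards: refilling a stage can add placeholders to it, and those placeholders must then become outputs (and possibly placeholders) of the stage before it. A self-contained toy of that demand propagation, with illustrative sets rather than the pass's real data structures:

    # Which values each toy stage defines and consumes ('h' skips from 0 to 3).
    defines = {0: {'x', 'h'}, 1: {'y'}, 2: {'z'}, 3: set()}
    consumes = {0: set(), 1: {'x'}, 2: {'y'}, 3: {'z', 'h'}}

    needed = set()
    placeholders = {}
    for stage in (3, 2, 1, 0):          # last stage first, as the pass iterates
        needed |= consumes[stage]       # demand from this stage and later ones
        needed -= defines[stage]        # demand satisfied locally
        placeholders[stage] = set(needed)   # must arrive as inputs to this stage

    print({s: sorted(placeholders[s]) for s in (0, 1, 2, 3)})
    # {0: [], 1: ['h', 'x'], 2: ['h', 'y'], 3: ['h', 'z']}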
@@ -118,7 +180,7 @@ def split_module_for_gpt2_test(
             _gen_all_ancestors_set(node)
         for n in list(all_ancestors):
-            if n.op != 'placeholder':
+            if n.op != 'placeholder' and n._fx_partition > partition_name:
                 n._fx_partition = partition_name

     def record_cross_partition_use(def_node: torch.fx.node.Node,
@@ -126,14 +188,14 @@ def split_module_for_gpt2_test(
         def_partition_name = getattr(def_node, '_fx_partition', None)
         use_partition_name = getattr(use_node, '_fx_partition', None)
         if def_partition_name != use_partition_name:
-            if 'tensor_meta' in def_node.meta:
-                if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
-                    _move_all_ancestors_into_partition(use_node, def_partition_name)
-                    node_process_list.extend(use_node.all_input_nodes)
-                    node_process_list.extend(list(use_node.users))
-                    node_process_list.append(use_node)
-                    return
+            # if 'tensor_meta' in def_node.meta:
+            #     if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
+            #         _move_all_ancestors_into_partition(use_node, def_partition_name)
+            #         node_process_list.extend(use_node.all_input_nodes)
+            #         node_process_list.extend(list(use_node.users))
+            #         node_process_list.append(use_node)
+            #         return

             if def_partition_name is not None:
                 def_partition = partitions[def_partition_name]
@@ -231,10 +293,11 @@ def split_module_for_gpt2_test(
             new_node = partition.graph.create_node(op=node.op,
                                                    target=target,
                                                    args=gathered_args,
-                                                   kwargs=gathered_kwargs)
+                                                   kwargs=gathered_kwargs,
+                                                   name=node.name)
             new_node.meta = node.meta.copy()
             partition.environment[node] = new_node
+    assert 'add_85' in orig_nodes

     # Set up values to construct base module
     base_mod_env: Dict[str, torch.fx.node.Node] = {}
     base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
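
Passing name=node.name keeps each copied node's identifier stable across the split, so nodes can still be looked up by their original names; the hard-coded assert 'add_85' in orig_nodes above looks like a model-specific debugging check that depends on exactly this. A standalone demonstration that torch.fx.Graph.create_node accepts an explicit name:

    import torch
    from torch.fx import Graph

    src = Graph()
    x = src.placeholder('x')
    add_85 = src.create_node('call_function', torch.add, (x, x), {}, name='add_85')

    dst = Graph()
    x2 = dst.placeholder('x')
    copied = dst.create_node(op=add_85.op,
                             target=add_85.target,
                             args=(x2, x2),
                             kwargs={},
                             name=add_85.name)
    assert copied.name == 'add_85'    # name preserved, no numeric suffix added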