mirror of https://github.com/hpcaitech/ColossalAI
[fx] update split module pass and add customized policy (#1373)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx]update split module pass and add customized policy
pull/1377/head
parent
be229217ce
commit
52bc2dc271
|
@ -61,6 +61,8 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
|
||||||
for node in mod_graph.nodes:
|
for node in mod_graph.nodes:
|
||||||
if pp_size <= 1:
|
if pp_size <= 1:
|
||||||
break
|
break
|
||||||
|
if 'pipe_split' in node.name:
|
||||||
|
continue
|
||||||
accumulate_node_size += node.node_size
|
accumulate_node_size += node.node_size
|
||||||
if accumulate_node_size >= partition_size:
|
if accumulate_node_size >= partition_size:
|
||||||
accumulate_node_size = 0
|
accumulate_node_size = 0
|
||||||
|
|
|
@ -5,11 +5,45 @@ from torch.fx._compatibility import compatibility
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from colossalai.fx.passes.meta_info_prop import TensorMetadata
|
from colossalai.fx.passes.meta_info_prop import TensorMetadata
|
||||||
import inspect
|
import inspect
|
||||||
|
from typing import List
|
||||||
from colossalai.fx.passes.split_module import Partition
|
from colossalai.fx.passes.split_module import Partition
|
||||||
from colossalai.fx.passes.adding_split_node_pass import pipe_split
|
from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
|
||||||
|
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
|
||||||
|
'''
|
||||||
|
This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future.
|
||||||
|
'''
|
||||||
|
mod_graph = gm.graph
|
||||||
|
valid_children_size = 0
|
||||||
|
valid_children = []
|
||||||
|
for node in mod_graph.nodes:
|
||||||
|
if node.op == "call_module":
|
||||||
|
valid_children_size += 1
|
||||||
|
valid_children.append(node.target)
|
||||||
|
if valid_children_size < pp_size:
|
||||||
|
# If valid children is not enough to shard, we will use balanced policy instead of uniform policy.
|
||||||
|
return balanced_split_pass(gm, pp_size)
|
||||||
|
accumulate_layer_amount = 0
|
||||||
|
list_of_part = partition_list
|
||||||
|
part_index = 0
|
||||||
|
for node in mod_graph.nodes:
|
||||||
|
if pp_size <= 1:
|
||||||
|
break
|
||||||
|
if node.op == "call_module":
|
||||||
|
if node.target in valid_children:
|
||||||
|
accumulate_layer_amount += 1
|
||||||
|
if accumulate_layer_amount == list_of_part[part_index]:
|
||||||
|
part_index += 1
|
||||||
|
pp_size -= 1
|
||||||
|
with mod_graph.inserting_after(node):
|
||||||
|
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||||
|
|
||||||
|
gm.recompile()
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
|
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
|
||||||
'''
|
'''
|
||||||
This pass will be used in gpt2 test, only a part of changes may be added into
|
This pass will be used in gpt2 test, only a part of changes may be added into
|
||||||
|
@ -25,21 +59,40 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
def eliminate_unused_outputs(gm, next_partition_placeholders):
|
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
|
||||||
'''
|
'''
|
||||||
This method is used to eliminate the outputs in previous partition which is unused in next partition.
|
This method is used to eliminate the outputs in previous partition which is unused in next partition.
|
||||||
|
In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel.
|
||||||
|
The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it
|
||||||
|
to partition 1 and partition 2. However, in single direction linked list, we need to do so.
|
||||||
'''
|
'''
|
||||||
|
output_type = None
|
||||||
|
output_args = []
|
||||||
|
non_output_list = []
|
||||||
|
new_placeholder_list = []
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.op == 'output':
|
if node.op == 'output':
|
||||||
output_type = node.args[0].__class__
|
output_type = node.args[0].__class__
|
||||||
output_args = list(node.args[0])
|
output_args.extend(list(node.args[0]))
|
||||||
for n in node.args[0]:
|
for n in node.args[0]:
|
||||||
if n.name not in next_partition_placeholders:
|
if next_partition_placeholders and n not in next_partition_placeholders:
|
||||||
output_args.remove(n)
|
output_args.remove(n)
|
||||||
gm.graph.erase_node(node)
|
gm.graph.erase_node(node)
|
||||||
|
else:
|
||||||
|
non_output_list.append(node.name)
|
||||||
|
for node in next_partition_placeholders:
|
||||||
|
if node not in output_args:
|
||||||
|
output_args.append(node)
|
||||||
|
for node in output_args:
|
||||||
|
if node.name not in non_output_list:
|
||||||
|
gm.graph.placeholder(node.name)
|
||||||
|
|
||||||
|
for node in gm.graph.nodes:
|
||||||
|
if node.op == 'placeholder':
|
||||||
|
new_placeholder_list.append(node)
|
||||||
gm.graph.output(output_type(output_args))
|
gm.graph.output(output_type(output_args))
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm, new_placeholder_list
|
||||||
|
|
||||||
def split_callback(n: torch.fx.Node):
|
def split_callback(n: torch.fx.Node):
|
||||||
nonlocal part_idx
|
nonlocal part_idx
|
||||||
|
@ -64,13 +117,22 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||||
placeholder_dict[submodule] = []
|
placeholder_dict[submodule] = []
|
||||||
for node in submodule.graph.nodes:
|
for node in submodule.graph.nodes:
|
||||||
if node.op == 'placeholder':
|
if node.op == 'placeholder':
|
||||||
placeholder_dict[submodule].append(node.name)
|
placeholder_dict[submodule].append(node)
|
||||||
|
output_dict = {}
|
||||||
|
for submodule in submodules:
|
||||||
|
output_dict[submodule] = []
|
||||||
|
for node in submodule.graph.nodes:
|
||||||
|
if node.op == 'output':
|
||||||
|
output_dict[submodule].append(node.name)
|
||||||
|
submodules.reverse()
|
||||||
for index, submodule in enumerate(submodules):
|
for index, submodule in enumerate(submodules):
|
||||||
if index >= len(submodules) - 1:
|
if index == 0:
|
||||||
break
|
placeholder_list = []
|
||||||
submodule = eliminate_unused_outputs(submodule, placeholder_dict[submodules[index + 1]])
|
else:
|
||||||
|
placeholder_list = placeholder_dict[submodules[index - 1]]
|
||||||
|
submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list)
|
||||||
submodule.recompile()
|
submodule.recompile()
|
||||||
|
|
||||||
split_mod.recompile()
|
split_mod.recompile()
|
||||||
|
|
||||||
return split_mod, split_submodules
|
return split_mod, split_submodules
|
||||||
|
@ -118,7 +180,7 @@ def split_module_for_gpt2_test(
|
||||||
|
|
||||||
_gen_all_ancestors_set(node)
|
_gen_all_ancestors_set(node)
|
||||||
for n in list(all_ancestors):
|
for n in list(all_ancestors):
|
||||||
if n.op != 'placeholder':
|
if n.op != 'placeholder' and n._fx_partition > partition_name:
|
||||||
n._fx_partition = partition_name
|
n._fx_partition = partition_name
|
||||||
|
|
||||||
def record_cross_partition_use(def_node: torch.fx.node.Node,
|
def record_cross_partition_use(def_node: torch.fx.node.Node,
|
||||||
|
@ -126,14 +188,14 @@ def split_module_for_gpt2_test(
|
||||||
def_partition_name = getattr(def_node, '_fx_partition', None)
|
def_partition_name = getattr(def_node, '_fx_partition', None)
|
||||||
use_partition_name = getattr(use_node, '_fx_partition', None)
|
use_partition_name = getattr(use_node, '_fx_partition', None)
|
||||||
if def_partition_name != use_partition_name:
|
if def_partition_name != use_partition_name:
|
||||||
if 'tensor_meta' in def_node.meta:
|
# if 'tensor_meta' in def_node.meta:
|
||||||
if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
|
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
|
||||||
_move_all_ancestors_into_partition(use_node, def_partition_name)
|
# _move_all_ancestors_into_partition(use_node, def_partition_name)
|
||||||
node_process_list.extend(use_node.all_input_nodes)
|
# node_process_list.extend(use_node.all_input_nodes)
|
||||||
node_process_list.extend(list(use_node.users))
|
# node_process_list.extend(list(use_node.users))
|
||||||
node_process_list.append(use_node)
|
# node_process_list.append(use_node)
|
||||||
|
|
||||||
return
|
# return
|
||||||
|
|
||||||
if def_partition_name is not None:
|
if def_partition_name is not None:
|
||||||
def_partition = partitions[def_partition_name]
|
def_partition = partitions[def_partition_name]
|
||||||
|
@ -231,10 +293,11 @@ def split_module_for_gpt2_test(
|
||||||
new_node = partition.graph.create_node(op=node.op,
|
new_node = partition.graph.create_node(op=node.op,
|
||||||
target=target,
|
target=target,
|
||||||
args=gathered_args,
|
args=gathered_args,
|
||||||
kwargs=gathered_kwargs)
|
kwargs=gathered_kwargs,
|
||||||
|
name=node.name)
|
||||||
new_node.meta = node.meta.copy()
|
new_node.meta = node.meta.copy()
|
||||||
partition.environment[node] = new_node
|
partition.environment[node] = new_node
|
||||||
|
assert 'add_85' in orig_nodes
|
||||||
# Set up values to construct base module
|
# Set up values to construct base module
|
||||||
base_mod_env: Dict[str, torch.fx.node.Node] = {}
|
base_mod_env: Dict[str, torch.fx.node.Node] = {}
|
||||||
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||||
|
|
Loading…
Reference in New Issue