mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
371 lines
16 KiB
371 lines
16 KiB
import torch
|
|
from torch.fx.graph_module import GraphModule
|
|
from typing import Callable, List, Dict, Any, Optional
|
|
from torch.fx._compatibility import compatibility
|
|
from packaging import version
|
|
from colossalai.fx.passes.meta_info_prop import TensorMetadata
|
|
import inspect
|
|
from typing import List
|
|
from colossalai.fx.passes.split_module import Partition
|
|
from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass
|
|
from torch.fx.node import Node
|
|
|
|
|
|
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
    '''
    Insert ``pipe_split`` marker nodes into ``gm.graph`` at caller-chosen layer boundaries.

    This pass is only used to do the gpt2 performance test, it may move into
    adding_split_node_pass.py, and will be deprecated in future.

    Args:
        gm: traced module; its graph is annotated in place.
        pp_size: number of pipeline stages wanted. At most ``pp_size - 1`` split nodes are inserted.
        partition_list: running (cumulative) ``call_module`` counts. A split node is inserted
            once the number of ``call_module`` nodes visited so far equals the current entry,
            after which the next entry becomes the current one. The counter is never reset.

    Returns:
        ``gm`` after ``recompile()``, or whatever ``balanced_split_pass(gm, pp_size)`` returns
        when the graph holds fewer ``call_module`` nodes than ``pp_size``.

    Raises:
        IndexError: if ``partition_list`` runs out of entries while more splits are still needed.
    '''
    mod_graph = gm.graph

    # Every call_module node counts as one shardable layer.
    call_module_count = sum(1 for node in mod_graph.nodes if node.op == "call_module")
    if call_module_count < pp_size:
        # If valid children is not enough to shard, we will use balanced policy instead of uniform policy.
        return balanced_split_pass(gm, pp_size)

    seen_layers = 0
    part_index = 0
    for node in mod_graph.nodes:
        if pp_size <= 1:
            break
        if node.op == "call_module":
            seen_layers += 1
        if part_index >= len(partition_list):
            # Same exception type a raw out-of-range lookup would raise, but with the cause spelled out.
            raise IndexError(f"partition_list has only {len(partition_list)} entries, "
                             f"but {pp_size - 1} more pipeline split(s) are still required")
        if seen_layers == partition_list[part_index]:
            part_index += 1
            pp_size -= 1
            # The marker goes right behind the node that completed the current stage.
            with mod_graph.inserting_after(node):
                mod_graph.create_node('call_function', pipe_split)

    gm.recompile()
    return gm
|
|
|
|
|
|
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
    '''
    Split a graph annotated with ``pipe_split`` nodes into a chain of submodules.

    This pass will be used in gpt2 test, only a part of changes may be added into
    split_with_split_nodes_pass, and it will be deprecated in future.

    Args:
        annotated_gm: GraphModule whose graph contains ``pipe_split`` call_function nodes
            marking the partition boundaries.

    Returns:
        A tuple ``(split_mod, split_submodules)``: the module returned by
        ``split_module_for_gpt2_test`` and the list of every ``torch.fx.GraphModule``
        yielded by ``split_mod.named_modules()``.
    '''
    # Partition counter shared with split_callback through ``nonlocal``.
    part_idx = 0

    def eliminate_unused_placeholders(gm):
        '''
        Erase every placeholder node of ``gm`` that has no users, recompile and return ``gm``.
        '''
        for node in gm.graph.nodes:
            if node.op == 'placeholder':
                if not len(node.users):
                    gm.graph.erase_node(node)
        gm.recompile()
        return gm

    def refill_outputs_and_placeholders(gm, next_partition_placeholders):
        '''
        This method is used to eliminate the outputs of the previous partition which are unused in the next partition.
        The split module pass treats partitions as a DAG, but in pipeline parallel we need to treat them as a
        singly linked list.
        The difference is that if an output from partition 0 is an input argument of partition 3, the DAG will not
        transfer it to partition 1 and partition 2. However, in a singly linked list, we need to do so.

        Args:
            gm: the partition (GraphModule) to rewrite in place.
            next_partition_placeholders: placeholder names of the partition that follows ``gm``;
                an empty list for the last partition.

        Returns:
            ``(gm, new_placeholder_list)`` where ``new_placeholder_list`` holds the names of all
            placeholder nodes of ``gm`` after the rewrite.
        '''
        output_type = None
        output_args = []
        non_output_list = []
        new_placeholder_list = []
        for node in gm.graph.nodes:
            if node.op == 'output':
                # Collect the names of the current outputs, remembering the container type
                # (tuple or list) so the new output node can use the same one.
                if isinstance(node.args[0], (tuple, list)):
                    output_type = node.args[0].__class__
                    output_args.extend([n.name for n in node.args[0]])
                else:
                    output_args.append(node.args[0].name)
                # Drop outputs the next partition does not consume. When the next partition
                # has no placeholders (empty list), nothing is dropped.
                rm_list = []
                for name in output_args:
                    if next_partition_placeholders and name not in next_partition_placeholders:
                        rm_list.append(name)
                for name in rm_list:
                    output_args.remove(name)
                # The old output node is removed; a new one is created at the end of this function.
                gm.graph.erase_node(node)
            else:
                non_output_list.append(node.name)

        # Everything the next partition expects must be emitted by this partition.
        for name in next_partition_placeholders:
            if name not in output_args:
                output_args.append(name)

        # A required output that no node of this partition produces is added as a new
        # placeholder, so the value is passed through from the preceding partition.
        for name in output_args:
            if name not in non_output_list:
                gm.graph.placeholder(name)

        # convert name to node for output_args
        for index, name in enumerate(output_args):
            for n in gm.graph.nodes:
                if n.name == name:
                    output_args[index] = n
                    # NOTE(review): this ``continue`` is the last statement of the loop body and has no effect.
                    continue

        # reorder the output args to make sure
        # output args has same order as next partition placeholder
        # NOTE(review): ``reorder_output_args`` is filled here but never read afterwards;
        # the output node below is built from ``output_args`` — confirm whether that is intended.
        reorder_output_args = []
        if next_partition_placeholders:
            for name in next_partition_placeholders:
                for node in output_args:
                    if node.name == name:
                        reorder_output_args.append(node)
                        continue

        for node in gm.graph.nodes:
            if node.op == 'placeholder':
                new_placeholder_list.append(node.name)
        # Re-create the output node, reusing the original container type when there was one.
        if output_type is not None:
            gm.graph.output(output_type(output_args))
        else:
            gm.graph.output(output_args)
        gm.recompile()
        return gm, new_placeholder_list

    def split_callback(n: torch.fx.Node):
        '''
        Return the partition index for ``n``. The index is advanced before it is returned
        whenever ``n`` is a ``pipe_split`` call, so the split node itself is assigned to the new partition.
        '''
        nonlocal part_idx
        if (n.op, n.target) == ('call_function', pipe_split):
            part_idx += 1
        return part_idx

    split_mod = split_module_for_gpt2_test(annotated_gm, None, split_callback)
    split_submodules = []
    # Strip the pipe_split marker nodes from every GraphModule found.
    # NOTE(review): named_modules() presumably also yields split_mod itself, in which case
    # the root module ends up in split_submodules as well — confirm this is expected by callers.
    for name, submodule in split_mod.named_modules():
        if isinstance(submodule, torch.fx.GraphModule):
            for node in submodule.graph.nodes:
                if (node.op, node.target) == ('call_function', pipe_split):
                    submodule.graph.erase_node(node)
            submodule.recompile()
            split_submodules.append(submodule)

    submodules = list(split_mod.children())
    placeholder_dict = {}
    for submodule in submodules:
        submodule = eliminate_unused_placeholders(submodule)
        placeholder_dict[submodule] = []
    # Walk the partitions from last to first so each one can see the placeholders
    # of the partition that follows it.
    submodules.reverse()
    for index, submodule in enumerate(submodules):
        if index == 0:
            # The first processed (i.e. last) partition has no successor.
            placeholder_list = []
        else:
            placeholder_list = placeholder_dict[submodules[index - 1]]
        submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list)
        submodule.recompile()

    split_mod.recompile()

    return split_mod, split_submodules
|
|
|
|
|
|
@compatibility(is_backward_compatible=True)
def split_module_for_gpt2_test(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[torch.fx.node.Node], int],
):
    """
    Split ``m`` into one submodule per partition and return a base module that calls them.

    This pass will be used in gpt2 pp performance test, only a part of changes may be added into
    split_module, and it will be deprecated in future.

    Args:
        m: the GraphModule to split. Targets of ``call_module``/``get_attr`` nodes are resolved on it.
        root_m: not referenced anywhere in this function body.
        split_callback: called once for every node that is neither a placeholder nor the output;
            ``str()`` of its return value is used as the partition name.

    Returns:
        A new ``torch.fx.graph_module.GraphModule`` whose graph re-creates the placeholders of ``m``,
        calls a ``submod_<partition_name>`` submodule for each partition in dependency order and
        reproduces the output of ``m``.

    Raises:
        RuntimeError: if the partitions depend on each other cyclically, or if a
            ``call_module``/``get_attr`` target cannot be found on ``m``.
    """
    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, torch.fx.node.Node] = {}

    # NOTE(review): within this function this helper is referenced only from the commented-out
    # code in record_cross_partition_use.
    def _node_with_all_tensor_element(node_metadata: Any) -> int:
        """
        return whether node contains non-tensor element.

        NOTE(review): the dict and iterable branches combine results with ``+=`` starting from
        ``True``, so their result can never become falsy; ``and`` may have been intended.
        """
        all_tensor_node = True

        if isinstance(node_metadata, TensorMetadata):
            all_tensor_node = node_metadata.is_tensor and all_tensor_node
        elif isinstance(node_metadata, dict):
            value_list = [v for _, v in node_metadata.items()]
            all_tensor_node += _node_with_all_tensor_element(value_list)
        else:
            for element in node_metadata:
                all_tensor_node += _node_with_all_tensor_element(element)

        return all_tensor_node

    # NOTE(review): within this function this helper is referenced only from the commented-out
    # code in record_cross_partition_use.
    def _move_all_ancestors_into_partition(node, partition_name):
        """
        Re-assign ``node`` and all of its transitive inputs (except placeholders) to
        ``partition_name`` when their current ``_fx_partition`` compares greater than it.
        """
        all_ancestors = set()

        def _gen_all_ancestors_set(node):
            # Depth-first walk over all_input_nodes; the set doubles as the visited marker.
            all_ancestors.add(node)
            for n in node.all_input_nodes:
                if n in all_ancestors:
                    continue
                _gen_all_ancestors_set(n)

        _gen_all_ancestors_set(node)
        for n in list(all_ancestors):
            if n.op != 'placeholder' and n._fx_partition > partition_name:
                n._fx_partition = partition_name

    def record_cross_partition_use(def_node: torch.fx.node.Node,
                                   use_node: Optional[torch.fx.node.Node]):    # noqa: B950
        """
        Record that ``use_node`` consumes ``def_node``. When both live in different partitions
        (or one of them has none), the value becomes an output of the defining partition and an
        input of the using partition, and the dependency between the two partitions is registered.
        ``use_node`` is ``None`` for values consumed by the graph output.
        """
        def_partition_name = getattr(def_node, '_fx_partition', None)
        use_partition_name = getattr(use_node, '_fx_partition', None)
        if def_partition_name != use_partition_name:
            # if 'tensor_meta' in def_node.meta:
            #     if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
            #         _move_all_ancestors_into_partition(use_node, def_partition_name)
            #         node_process_list.extend(use_node.all_input_nodes)
            #         node_process_list.extend(list(use_node.users))
            #         node_process_list.append(use_node)

            #         return

            # ``setdefault(key)`` is used as an insert-if-absent; presumably these Partition
            # attributes are dicts serving as ordered sets — TODO confirm against Partition.
            if def_partition_name is not None:
                def_partition = partitions[def_partition_name]
                def_partition.outputs.setdefault(def_node.name)
                if use_partition_name is not None:
                    def_partition.partition_dependents.setdefault(use_partition_name)

            if use_partition_name is not None:
                use_partition = partitions[use_partition_name]
                use_partition.inputs.setdefault(def_node.name)
                if def_partition_name is not None:
                    use_partition.partitions_dependent_on.setdefault(def_partition_name)

    node_process_list = list(m.graph.nodes)
    # split nodes into partitions
    while node_process_list:
        node = node_process_list.pop(0)
        orig_nodes[node.name] = node

        # Placeholders are never assigned to a partition; they reach partitions as inputs.
        if node.op in ["placeholder"]:
            continue
        if node.op == 'output':
            # partition_name = str(split_callback(node))
            # def _set_output_args_partition(n, partition_name):
            #     n._fx_partition = partition_name
            # torch.fx.graph.map_arg(node.args[0], lambda n: _set_output_args_partition(n, partition_name))
            # Values returned by the graph are marked as outputs of the partition defining them.
            torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
            continue
        partition_name = str(split_callback(node))

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        # An already present ``_fx_partition`` attribute is kept rather than overwritten.
        origin_partition_name = getattr(node, '_fx_partition', None)
        if origin_partition_name is None:
            node._fx_partition = partition_name

        torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
        torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node))    # noqa: B950

    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.partitions_dependent_on):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        # Releasing this partition may make its dependents dependency-free.
        for dependent in partitions[root_partition].partition_dependents:
            partitions[dependent].partitions_dependent_on.pop(root_partition)
            if not partitions[dependent].partitions_dependent_on:
                root_partitions.append(dependent)
    # Any partition left unsorted still has an unresolved dependency, i.e. it is part of a cycle.
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # add placeholders to partitions
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        # NOTE(review): the loop variable ``input`` shadows the builtin of the same name.
        for input in partition.inputs:
            placeholder = partition.graph.placeholder(input)
            placeholder.meta = orig_nodes[input].meta.copy()
            # Inside this partition the original node is represented by the new placeholder.
            partition.environment[orig_nodes[input]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, '_fx_partition'):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])

            if node.op not in ['call_module', 'get_attr']:
                target = node.target
            else:
                # Resolve the dotted target on ``m`` one attribute at a time.
                target_atoms = node.target.split('.')
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise RuntimeError(f'Operator target {node.target} not found!')
                    target_attr = getattr(target_attr, atom)
                # target = target_atoms[-1]
                # The full dotted path is flattened with underscores to form the name
                # under which the attribute is stored in the partition's targets.
                target = '_'.join(target_atoms)
                partition.targets[target] = target_attr

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            # The copy keeps the original node name and metadata.
            new_node = partition.graph.create_node(op=node.op,
                                                   target=target,
                                                   args=gathered_args,
                                                   kwargs=gathered_kwargs,
                                                   name=node.name)
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Set up values to construct base module
    base_mod_env: Dict[str, torch.fx.node.Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    for node in m.graph.nodes:
        if node.op == 'placeholder':
            # For torch older than 1.11.0 the placeholder is created without ``default_value``
            # (presumably because that argument is not available there — TODO confirm).
            if version.parse(torch.__version__) < version.parse('1.11.0'):
                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
            else:
                default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
                                                                     type_expr=node.type,
                                                                     default_value=default_value)
            base_mod_env[node.name].meta = node.meta.copy()

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order

    for partition_name in sorted_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
        # A single output is returned bare instead of as a 1-tuple.
        output_vals = output_vals[0] if len(output_vals) == 1 else output_vals    # type: ignore[assignment]
        partition.graph.output(output_vals)

        # Construct GraphModule for this partition
        submod_name = f'submod_{partition_name}'
        base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
                                                                        partition.graph)    # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
        if len(partition.outputs) > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node    # type: ignore[index]
        else:
            # A partition without outputs contributes nothing to the base environment.
            if not partition.outputs:
                continue
            base_mod_env[list(partition.outputs)[0]] = output_val

    # Re-create the output of ``m`` in the base graph from the values registered above.
    for node in m.graph.nodes:
        if node.op == 'output':
            base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]))    # noqa: B950

    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|