[Pipeline] Add Topo Class (#2059)

* use Topo class to rewrite DAG

* polish code

* polish code

* polish code

* add comment

* add else to unended if

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang 2022-12-02 18:13:20 +08:00 committed by GitHub
parent e4293e5077
commit 44ea461890
10 changed files with 451 additions and 283 deletions


@@ -3,7 +3,6 @@ from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.utils import get_DAG
import inspect
@@ -294,11 +293,5 @@ def split_module(
        partition = partitions[partition_name]

    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
    DAG = get_DAG(new_gm)
    for _, submodule in new_gm.named_modules():
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_DAG', DAG)
    return new_gm


@@ -1,8 +1,7 @@
import torch
from typing import Dict, Set
from typing import Dict
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule

def get_comm_size(prev_partition, next_partition):
    """
@@ -171,161 +170,3 @@ def get_node_module(node) -> torch.nn.Module:
    module = node.graph.owning_module.get_submodule(node.target)
    return module

def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
    # find def in input
    if input_partitions is not None:
        for placeholder in input_partitions:
            if placeholder.name == node.name:
                return 'MODEL_INPUT'

    # find direct def
    if direct:
        for partition in partitions:
            if node == partition:
                return partition.name
    # find def with getitem call
    else:
        for partition in partitions:
            if node in partition.users.keys():
                return partition.name

    print(f'Not found def in partition {node.name}')
    return None

def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
    user_partition_names = []
    # find direct user
    if direct:
        for partition in partitions:
            if node == partition:
                user_partition_names.append(partition.name)
    # find user with getitem call
    else:
        for partition in partitions:
            if node in partition.args:
                user_partition_names.append(partition.name)

    if output_partitions is not None:
        output_node = output_partitions[0]
        if node.op == output_node.op:
            user_partition_names.append('MODEL_OUTPUT')

    if len(user_partition_names) > 0:
        return user_partition_names

    print(f'Not found user in partition {node.name}')
    return None

def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}},
    input = {}
    output = {}

    for offset, arg in enumerate(partition.args):
        def_partition_name = None
        if not arg.name.startswith('getitem'):
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
        else:
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
        if def_partition_name is None:
            continue
        if def_partition_name not in input:
            input[def_partition_name] = []
        input[def_partition_name].append(offset)

    offset = -1
    for user in partition.users.keys():
        user_partition_names = None
        if input_partitions is None or not user.name.startswith('getitem'):
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
            offset = 0
        else:
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
            offset += 1
        if user_partition_names is None:
            continue
        for user_partition_name in user_partition_names:
            if user_partition_name not in output:
                output[user_partition_name] = []
            output[user_partition_name].append(offset)

    return input, output, offset + 1
# The DAG looks like the following example; each int in a list is the offset of the
# partition's input arg or output arg.
# {
#     'input_partition': {
#         'input_ids': {
#             'input': {},
#             'output': {'submod_0': [0], 'submod_1': [1]},
#             'output_len': 0},
#         'attention_mask': {
#             'input': {},
#             'output': {'submod_2': [0]},
#             'output_len': 0}},
#     'submod_0': {
#         'input': {'MODEL_INPUT': [0]},
#         'output': {'submod_1': [0], 'submod_2': [0, 1]},
#         'output_len': 2},
#     'submod_1': {
#         'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
#         'output': {'submod_2': [0]},
#         'output_len': 1},
#     'submod_2': {
#         'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
#         'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
#                                 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
#                                 22, 23, 24]},
#         'output_len': 25},
#     'submod_3': {
#         'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
#                                12, 13, 14, 15, 16, 17, 18, 19, 20,
#                                21, 22, 23, 24]},
#         'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
#                                     11, 12, 13, 14, 15, 16, 17, 18, 19,
#                                     20, 21, 22, 23, 24]},
#         'output_len': 25},
#     'output_partition': {
#         'input': {'logits': 'submod_3',
#                   'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                       ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                       ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                       ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                       ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                       ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
#         'output': {},
#         'output_len': 0}
# }
# TODO(jiangziyue) Define a Class for DAG.
def get_DAG(gm: GraphModule):
    DAG = {}
    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)

    for partition in input_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 1}
        _, output, _ = get_partition_depends(partition, partitions, None, output_partitions)
        DAG_node['output'] = output
        if 'input_partition' not in DAG:
            DAG['input_partition'] = {}
        DAG['input_partition'][partition.name] = DAG_node

    for partition in partitions:
        DAG_node = {'input': {}, 'output': {}}
        DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(partition, partitions, input_partitions, output_partitions)
        DAG[partition.name] = DAG_node

    for partition in output_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 0}
        DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0], lambda n: find_def_in_partition(n, partitions, input_partitions))
        DAG['output_partition'] = DAG_node

    return DAG
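
For comparison with the new Topo class introduced by this PR, a minimal sketch of how this dict was meant to be consumed, assuming a DAG shaped like the commented example above and using `split_gm` as a placeholder name for a GraphModule produced by the split pass:

    DAG = get_DAG(split_gm)
    # which partitions feed submod_1, and which partitions consume its outputs?
    for src_partition, arg_offsets in DAG['submod_1']['input'].items():
        print(f'submod_1 args {arg_offsets} come from {src_partition}')
    for dst_partition, out_offsets in DAG['submod_1']['output'].items():
        print(f'submod_1 output offsets {out_offsets} feed {dst_partition}')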


@@ -0,0 +1,3 @@
from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal

__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']


@@ -0,0 +1,3 @@
from .fx import get_topology as get_fx_topology

__all__ = ['get_fx_topology']


@@ -0,0 +1,145 @@
from torch.fx.graph_module import GraphModule
from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch

def partition_name_to_id(partition_name, is_input=False, is_output=False):
    if is_input:
        partition_id = 0
    elif is_output:
        partition_id = 1
    else:
        prefix = 'submod_'
        partition_id = int(partition_name.split(prefix)[-1]) + 2
    return partition_id

# There are two kinds of def in fx.graph:
# 1. non direct_use & non direct_def, which means the output is used by the next partition through a temporary mid value.
#    e.g. submod1 = call_module(...)
#         temporary_val = submod1[0]
#         submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def, which means the output is used by the next partition directly.
#    e.g. submod1 = call_module(...)
#         submod2 = call_module(submod1, ...)
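# An illustrative sketch of how the two patterns above appear when torch.fx prints a
# split graph (the node and partition names here are hypothetical, and the printed
# form is simplified):
#
#     %submod_0 = call_module[target=submod_0](args = (%x,))
#     %getitem = call_function[target=operator.getitem](args = (%submod_0, 0))
#     %submod_1 = call_module[target=submod_1](args = (%getitem,))    # non direct_use
#
#     %submod_0 = call_module[target=submod_0](args = (%x,))
#     %submod_1 = call_module[target=submod_1](args = (%submod_0,))   # direct_use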

def find_input_in_partition(node, partitions, input_partitions=None):
    p_input_val = None
    direct_def = not node.name.startswith('getitem')
    # search in input
    if direct_def and input_partitions is not None:
        partition_id = partition_name_to_id('', is_input=True)
        for i, input_node in enumerate(input_partitions):
            if input_node == node:
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
                return p_input_val

    # search submod in mid part
    if direct_def:
        for partition in partitions:
            if partition == node:
                partition_id = partition_name_to_id(partition.name)
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
                return p_input_val
    # search temporary value in graph
    else:
        for partition in partitions:
            for offset, mid_val in enumerate(partition.users):
                if mid_val == node:
                    partition_id = partition_name_to_id(partition.name)
                    p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
                    return p_input_val

    return p_input_val

def find_output_in_partition(node, partitions, output_partitions=None):
    p_output_val = PartitionOutputVal()
    for user in node.users:
        direct_use = not user.name.startswith('getitem')
        # user is mid partition
        for partition in partitions:
            # direct call
            if direct_use:
                if user == partition:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == node:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break
            # getitem call
            else:
                if user in partition.args:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == user:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break

        # user is output
        if output_partitions is not None:
            output_node = output_partitions[0]
            if user.op == output_node.op:
                output_keys = {}
                partition_id = partition_name_to_id('', is_output=True)
                torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
                for i, arg in enumerate(output_keys):
                    if arg == node:
                        p_output_val.add(partition_id=partition_id, offset=i)
                        break
    return p_output_val

def get_topology(gm: GraphModule):
    topo = Topo()
    topo_output_partition = Partition()

    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)
        else:
            continue

    # set output for input_partition
    topo_input_partition = Partition()
    for partition in input_partitions:
        cur_node = partition
        p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
        topo_input_partition.add_output_val(p_output_val)
    topo.set_partitions(partition_id=0, partition=topo_input_partition)
    topo.set_input_partition(partition_id=0)

    for i, partition in enumerate(partitions):
        topo_mid_partition = Partition()
        # set input for submodule
        for arg in partition.args:
            cur_node = arg
            p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
            topo_mid_partition.add_input_val(p_input_val)
        # set output for submodule
        direct_use = True
        for user in partition.users:
            if user.name.startswith('getitem'):
                direct_use = False
                break
        if direct_use:
            cur_node = partition
            p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
            topo_mid_partition.add_output_val(p_output_val)
        else:
            for user in partition.users:
                cur_node = user
                p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
                topo_mid_partition.add_output_val(p_output_val)
        topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)

    # set input for output_partition
    for partition in output_partitions:
        topo_output_partition = Partition()
        torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
            find_input_in_partition(n, partitions, input_partitions)))
    topo.set_partitions(partition_id=1, partition=topo_output_partition)
    topo.set_output_partition(partition_id=1)

    return topo
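
A minimal usage sketch of this adaptor, assuming `split_gm` is a placeholder for the top-level GraphModule returned by the split pass (the test utilities below build one the same way):

    from colossalai.pipeline.middleware.adaptor import get_fx_topology

    topo = get_fx_topology(split_gm)
    # partition_id 0 is the model input, 1 is the model output, and submod_k maps to k + 2
    for pid, partition in topo.get_mid_partitions().items():
        n_in = len(partition.get_input_vals())
        n_out = len(partition.get_output_vals())
        print(f'partition {pid}: {n_in} inputs, {n_out} outputs')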


@@ -0,0 +1,164 @@
from typing import Dict, List
from dataclasses import dataclass

# This file includes the data structures used by the Pipeline Middleware.

@dataclass
class ValPosition:
    partition_id: int
    offset: int

    def __str__(self) -> str:
        res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
        return res

    def __repr__(self) -> str:
        return self.__str__()

class PartitionInputVal(object):

    def __init__(self, partition_id, offset) -> None:
        # every input records which partition_id and which offset it comes from
        val_pos = ValPosition(partition_id, offset)
        self._from_partition_and_offset: ValPosition = val_pos

    def get(self):
        return self._from_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += f'<-({self._from_partition_and_offset})'
        return res

    def __repr__(self) -> str:
        return self.__str__()

class PartitionOutputVal(object):

    def __init__(self) -> None:
        # every output records which partition_ids and which offsets it goes to
        self._to_partition_and_offset: List[ValPosition] = []

    def add(self, partition_id, offset):
        val_pos = ValPosition(partition_id, offset)
        self._to_partition_and_offset.append(val_pos)

    def get(self):
        return self._to_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += '->('
        for val_pos in self._to_partition_and_offset:
            res += f'{val_pos},'
        res += ')'
        return res

    def __repr__(self) -> str:
        return self.__str__()

class Partition(object):

    def __init__(self) -> None:
        self._input_vals: List[PartitionInputVal] = []
        self._output_vals: List[PartitionOutputVal] = []

    def add_input_val(self, input_val: PartitionInputVal):
        self._input_vals.append(input_val)

    def add_output_val(self, output_val: PartitionOutputVal):
        self._output_vals.append(output_val)

    def get_input_vals(self):
        return self._input_vals

    def get_output_vals(self):
        return self._output_vals

    def __str__(self) -> str:
        res = ''
        res += f' input:\n'
        res += f' length:{len(self._input_vals)}\n'
        for i, input_val in enumerate(self._input_vals):
            res += f' offset={i}:{input_val}\n'
        res += f' output:\n'
        res += f' length:{len(self._output_vals)}\n'
        for i, output_val in enumerate(self._output_vals):
            res += f' offset={i}:{output_val}\n'
        return res

    def __repr__(self) -> str:
        return self.__str__()
# This class is a middleware between the partition splitter and the Pipeline
# Scheduler. It records the graph info about partition inputs/outputs and
# provides it to the scheduler.
# There are three kinds of partitions in the Pipeline Middleware design, which
# together represent the whole process of a model execution: input-fwd-output
# 1. input_partition: records the input of a model.
# 2. mid_partition: records the split forward execution of a model.
# 3. output_partition: records the output of a model.
# attributes:
#     _partitions: include all partitions
#     _input_partition_id: the key that represents the input_partition
#     _output_partition_id: the key that represents the output_partition
class Topo(object):

    def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
        self._partitions: Dict[int, Partition] = {}
        self._input_partition_id = input_partition_id
        self._output_partition_id = output_partition_id

    def set_input_partition(self, partition_id: int):
        self._input_partition_id = partition_id

    def set_output_partition(self, partition_id: int):
        self._output_partition_id = partition_id

    def set_partitions(self, partition_id: int, partition: Partition):
        self._partitions[partition_id] = partition

    def get_mid_partitions(self):
        res = {}    # {partition_id: Partition}
        for partition_id, partition in self._partitions.items():
            if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
                continue
            res[partition_id] = partition
        return res

    def get_input_partition(self):
        if self._input_partition_id is not None:
            return self._partitions[self._input_partition_id]
        return None

    def get_output_partition(self):
        if self._output_partition_id is not None:
            return self._partitions[self._output_partition_id]
        return None

    def __str__(self) -> str:
        res = ''
        if len(self._partitions) == 0:
            return 'Empty Topo Graph.'

        input_part = self.get_input_partition()
        if input_part is not None:
            res += '{\n'
            res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
            res += '}\n'

        mid_parts = self.get_mid_partitions()
        for i, (partition_id, part) in enumerate(mid_parts.items()):
            res += '{\n'
            res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
            res += '}\n'

        output_part = self.get_output_partition()
        if output_part is not None:
            res += '{\n'
            res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
            res += '}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()
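
A small hand-built sketch of how these pieces compose. The partition ids follow the input=0, output=1, submod_k=k+2 convention used by the fx adaptor above; the wiring itself is made up purely for illustration:

    from colossalai.pipeline.middleware import Topo, Partition, PartitionInputVal, PartitionOutputVal

    topo = Topo()

    # the single model input feeds offset 0 of submod_0 (partition_id 2)
    input_partition = Partition()
    to_submod_0 = PartitionOutputVal()
    to_submod_0.add(partition_id=2, offset=0)
    input_partition.add_output_val(to_submod_0)
    topo.set_partitions(partition_id=0, partition=input_partition)
    topo.set_input_partition(partition_id=0)

    # submod_0 reads offset 0 of the model input and sends its only output to the model output
    submod_0 = Partition()
    submod_0.add_input_val(PartitionInputVal(partition_id=0, offset=0))
    to_output = PartitionOutputVal()
    to_output.add(partition_id=1, offset=0)
    submod_0.add_output_val(to_output)
    topo.set_partitions(partition_id=2, partition=submod_0)

    print(topo)    # prints the recorded input partition and submod_0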


@@ -1,85 +0,0 @@
import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(**kwargs)

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")

    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    return top_module, split_submodules[0]._DAG

def check_input(input, input_node, top_module):
    for user in input_node.users.keys():
        partition_name = user.name
        assert partition_name in input['output']

def check_submod(submod_partition, node, top_module):
    for arg in node.args:
        input_part_name = None
        if arg.op == 'placeholder':
            input_part_name = 'MODEL_INPUT'
        elif not arg.name.startswith('getitem'):
            input_part_name = arg.name
        else:
            input_part_name = arg.args[0].name
        assert input_part_name in submod_partition['input']

    for user in node.users:
        output_part_names = []
        if user.op == 'output':
            output_part_names.append('MODEL_OUTPUT')
        elif not user.name.startswith('getitem'):
            output_part_names.append(user.name)
        else:
            for n in user.users:
                if n.op == 'output':
                    output_part_names.append('MODEL_OUTPUT')
                else:
                    output_part_names.append(n.name)

        for output_part_name in output_part_names:
            assert output_part_name in submod_partition['output']

def check_DAG(top_module, DAG):
    assert 'input_partition' in DAG
    input_partition = DAG['input_partition']

    for node in top_module.graph.nodes:
        # check input
        if node.op == 'placeholder':
            assert node.name in input_partition
            input = input_partition[node.name]
            check_input(input, node, top_module)
        elif node.op == 'call_module':
            assert node.name in DAG
            submod_partition = DAG[node.name]
            check_submod(submod_partition, node, top_module)


@@ -1,31 +0,0 @@
import pytest
import torch
import transformers
from dag_utils import split_model_and_get_DAG, check_DAG

BATCH_SIZE = 1
SEQ_LENGHT = 16

@pytest.mark.skip('balance split v2 is not ready')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        #transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        top_mod, DAG = split_model_and_get_DAG(model, data_gen)
        check_DAG(top_mod, DAG)

if __name__ == '__main__':
    test_opt()


@@ -0,0 +1,43 @@
import pytest
import torch
import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP

BATCH_SIZE = 1
SEQ_LENGHT = 16

def test_opt():
    MODEL_LIST = [
        MLP,
        transformers.OPTModel,
    ]

    CONFIGS = [
        {'dim': 10, 'layers': 12},
        transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
    ]

    def data_gen_MLP():
        x = torch.zeros((16, 10))
        kwargs = dict(x=x)
        return kwargs

    def data_gen_OPT():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    DATAGEN = [
        data_gen_MLP,
        data_gen_OPT,
    ]

    for i, model_cls in enumerate(MODEL_LIST):
        model = model_cls(config=CONFIGS[i])
        top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i])
        # print(f'{top_mod=}\n----\n{topo=}')
        check_topo(top_mod, topo)

if __name__ == '__main__':
    test_opt()


@@ -0,0 +1,92 @@
import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.middleware.adaptor import get_fx_topology
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

class MLP(torch.nn.Module):

    def __init__(self, config={}):
        super().__init__()
        dim = config['dim']
        layers = config['layers']
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(torch.nn.Linear(dim, dim, bias=False))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")

    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_topo', topo)

    return top_module, split_submodules[0]._topo

def check_input(top_module, input_partition: Partition):
    partition_output = input_partition.get_output_vals()
    arg_pos = 0
    for node in top_module.graph.nodes:
        if node.op == 'placeholder':
            cur_checkee = partition_output[arg_pos]
            to_partition_and_offset = cur_checkee.get()
            assert len(to_partition_and_offset) == len(node.users.keys())
            arg_pos += 1

    assert arg_pos == len(partition_output)

def check_submod(top_module, part_id, mid_partition: Partition):
    partition_input = mid_partition.get_input_vals()
    partition_output = mid_partition.get_output_vals()

    cnt = 1
    cur_node = None
    for node in top_module.graph.nodes:
        if node.name.startswith('submod'):
            cnt += 1
        if cnt == part_id:
            cur_node = node
            break

    assert len(partition_input) == len(cur_node.args)
    assert len(partition_output) == len(cur_node.users)

def check_topo(top_module, topo: Topo):
    input_partition = topo.get_input_partition()
    mid_partitions = topo.get_mid_partitions()

    check_input(top_module, input_partition)
    for part_id, submod in mid_partitions.items():
        check_submod(top_module, part_id, submod)