mirror of https://github.com/hpcaitech/ColossalAI

[Pipeline] Add Topo Class (#2059)

* use Topo class to rewrite DAG
* polish code
* polish code
* polish code
* add comment
* add else to unended if

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
pull/2071/head

parent e4293e5077
commit 44ea461890
@@ -3,7 +3,6 @@ from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.utils import get_DAG
import inspect

@@ -294,11 +293,5 @@ def split_module(
        partition = partitions[partition_name]

    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)

    DAG = get_DAG(new_gm)

    for _, submodule in new_gm.named_modules():
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_DAG', DAG)

    return new_gm
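With this change, split_module no longer builds and attaches the dict-based `_DAG`. A minimal sketch of the new pattern (editor's addition, mirroring topo_utils.py added later in this diff; the helper name attach_topology is an assumption, not repo code):

import torch
from colossalai.pipeline.middleware.adaptor import get_fx_topology

def attach_topology(top_module, split_submodules):
    # Build the Topo from the split top-level GraphModule and attach it to
    # every split submodule, as the new test utilities below do.
    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_topo', topo)
    return topo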
@@ -1,8 +1,7 @@
import torch
from typing import Dict, Set
from typing import Dict
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule

def get_comm_size(prev_partition, next_partition):
    """

@@ -171,161 +170,3 @@ def get_node_module(node) -> torch.nn.Module:
    module = node.graph.owning_module.get_submodule(node.target)
    return module

def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
    # find def in input
    if input_partitions is not None:
        for placeholder in input_partitions:
            if placeholder.name == node.name:
                return 'MODEL_INPUT'

    # find direct def
    if direct:
        for partition in partitions:
            if node == partition:
                return partition.name
    # find def with getitem call
    else:
        for partition in partitions:
            if node in partition.users.keys():
                return partition.name

    print(f'Not found def in partition {node.name}')
    return None

def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
    user_partition_names = []
    # find direct user
    if direct:
        for partition in partitions:
            if node == partition:
                user_partition_names.append(partition.name)

    # find user with getitem call
    else:
        for partition in partitions:
            if node in partition.args:
                user_partition_names.append(partition.name)

    if output_partitions is not None:
        output_node = output_partitions[0]
        if node.op == output_node.op:
            user_partition_names.append('MODEL_OUTPUT')

    if len(user_partition_names) > 0:
        return user_partition_names

    print(f'Not found user in partition {node.name}')
    return None

def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}},
    input = {}
    output = {}

    for offset, arg in enumerate(partition.args):
        def_partition_name = None
        if not arg.name.startswith('getitem'):
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
        else:
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
        if def_partition_name is None:
            continue
        if def_partition_name not in input:
            input[def_partition_name] = []
        input[def_partition_name].append(offset)

    offset = -1
    for user in partition.users.keys():
        user_partition_names = None
        if input_partitions is None or not user.name.startswith('getitem'):
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
            offset = 0
        else:
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
            offset += 1
        if user_partition_names is None:
            continue
        for user_partition_name in user_partition_names:
            if user_partition_name not in output:
                output[user_partition_name] = []
            output[user_partition_name].append(offset)

    return input, output, offset+1

# DAG just looks like following case.
# the int in every list represents the offset of the partition's input arg or output arg.
# {
#     'input_partition': {
#         'input_ids': {
#             'input': {},
#             'output': {'submod_0': [0], 'submod_1': [1]},
#             'output_len': 0},
#         'attention_mask': {
#             'input': {},
#             'output': {'submod_2': [0]},
#             'output_len': 0}},
#     'submod_0': {
#         'input': {'MODEL_INPUT': [0]},
#         'output': {'submod_1': [0], 'submod_2': [0, 1]},
#         'output_len': 2},
#     'submod_1': {
#         'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
#         'output': {'submod_2': [0]},
#         'output_len': 1},
#     'submod_2': {
#         'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
#         'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
#                                 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
#                                 22, 23, 24]},
#         'output_len': 25},
#     'submod_3': {
#         'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
#                                12, 13, 14, 15, 16, 17, 18, 19, 20,
#                                21, 22, 23, 24]},
#         'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
#                                     11, 12, 13, 14, 15, 16, 17, 18, 19,
#                                     20, 21, 22, 23, 24]},
#         'output_len': 25},
#     'output_partition': {
#         'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
#                                                             ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
#         'output': {}, 'output_len': 0}
# }

# TODO(jiangziyue) Define a Class for DAG.
def get_DAG(gm: GraphModule):
    DAG = {}
    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)

    for partition in input_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 1}
        _, output, _ = get_partition_depends(partition, partitions, None, output_partitions)
        DAG_node['output'] = output
        if 'input_partition' not in DAG:
            DAG['input_partition'] = {}
        DAG['input_partition'][partition.name] = DAG_node

    for partition in partitions:
        DAG_node = {'input': {}, 'output': {}}
        DAG_node['input'], DAG_node['output'], DAG_node['output_len'] = get_partition_depends(partition, partitions, input_partitions, output_partitions)
        DAG[partition.name] = DAG_node

    for partition in output_partitions:
        DAG_node = {'input': {}, 'output': {}, 'output_len': 0}
        DAG_node['input'] = torch.fx.graph.map_arg(partition.args[0], lambda n: find_def_in_partition(n, partitions, input_partitions))
        DAG['output_partition'] = DAG_node

    return DAG
@@ -0,0 +1,3 @@
from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal

__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
@@ -0,0 +1,3 @@
from .fx import get_topology as get_fx_topology

__all__ = ['get_fx_topology']
@@ -0,0 +1,145 @@
from torch.fx.graph_module import GraphModule
from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch

def partition_name_to_id(partition_name, is_input=False, is_output=False):
    if is_input:
        partition_id = 0
    elif is_output:
        partition_id = 1
    else:
        prefix = 'submod_'
        partition_id = int(partition_name.split(prefix)[-1]) + 2
    return partition_id

# There are two kinds of def in fx.graph
# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value.
#    e.g. submod1 = call_module(...)
#         temporary_val = submod1[0]
#         submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def, which means the output is used by next partition directly.
#    e.g. submod1 = call_module(...)
#         submod2 = call_module(submod1, ...)
def find_input_in_partition(node, partitions, input_partitions=None):
    p_input_val = None
    direct_def = not node.name.startswith('getitem')
    # search in input
    if direct_def and input_partitions is not None:
        partition_id = partition_name_to_id('', is_input=True)
        for i, input_node in enumerate(input_partitions):
            if input_node == node:
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
                return p_input_val
    # search submod in mid part
    if direct_def:
        for partition in partitions:
            if partition == node:
                partition_id = partition_name_to_id(partition.name)
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
                return p_input_val
    # search temporary value in graph
    else:
        for partition in partitions:
            for offset, mid_val in enumerate(partition.users):
                if mid_val == node:
                    partition_id = partition_name_to_id(partition.name)
                    p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
                    return p_input_val

    return p_input_val

def find_output_in_partition(node, partitions, output_partitions=None):
    p_output_val = PartitionOutputVal()
    for user in node.users:
        direct_use = not user.name.startswith('getitem')
        # user is mid partition
        for partition in partitions:
            # direct call
            if direct_use:
                if user == partition:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == node:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break
            # getitem call
            else:
                if user in partition.args:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == user:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break

        # user is output
        if output_partitions is not None:
            output_node = output_partitions[0]
            if user.op == output_node.op:
                output_keys = {}
                partition_id = partition_name_to_id('', is_output=True)
                torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
                for i, arg in enumerate(output_keys):
                    if arg == node:
                        p_output_val.add(partition_id=partition_id, offset=i)
                        break
    return p_output_val

def get_topology(gm: GraphModule):
    topo = Topo()
    topo_output_partition = Partition()

    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)
        else:
            continue

    # set output for input_partition
    topo_input_partition = Partition()
    for partition in input_partitions:
        cur_node = partition
        p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
        topo_input_partition.add_output_val(p_output_val)
    topo.set_partitions(partition_id=0, partition=topo_input_partition)
    topo.set_input_partition(partition_id=0)

    for i, partition in enumerate(partitions):
        topo_mid_partition = Partition()
        # set input for submodule
        for arg in partition.args:
            cur_node = arg
            p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
            topo_mid_partition.add_input_val(p_input_val)
        # set output for submodule
        direct_use = True
        for user in partition.users:
            if user.name.startswith('getitem'):
                direct_use = False
                break
        if direct_use:
            cur_node = partition
            p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
            topo_mid_partition.add_output_val(p_output_val)
        else:
            for user in partition.users:
                cur_node = user
                p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
                topo_mid_partition.add_output_val(p_output_val)
        topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)

    # set input for output_partition
    for partition in output_partitions:
        topo_output_partition = Partition()
        torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
            find_input_in_partition(n, partitions, input_partitions)))
    topo.set_partitions(partition_id=1, partition=topo_output_partition)
    topo.set_output_partition(partition_id=1)

    return topo
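Illustrative use of get_topology (an editor's sketch based on topo_utils.py further down in this diff; build_topology is a hypothetical helper name, not repo code). Partition ids follow partition_name_to_id above: 0 is the model input partition, 1 is the model output partition, and submod_k maps to k + 2.

import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.pipeline.middleware.adaptor import get_fx_topology

def build_topology(model: torch.nn.Module, meta_args: dict):
    # Trace the model, split it into two pipeline stages, then derive the Topo
    # from the top-level GraphModule that contains the submod_* call nodes.
    graph = ColoTracer().trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    annotated = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated)
    return top_module, get_fx_topology(top_module)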
@@ -0,0 +1,164 @@
from typing import Dict, List
from dataclasses import dataclass

# This file includes data structures used by the Pipeline Middleware.

@dataclass
class ValPosition:
    partition_id: int
    offset: int

    def __str__(self) -> str:
        res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
        return res

    def __repr__(self) -> str:
        return self.__str__()

class PartitionInputVal(object):
    def __init__(self, partition_id, offset) -> None:
        # records which partition_id and which offset every input comes from
        val_pos = ValPosition(partition_id, offset)
        self._from_partition_and_offset: ValPosition = val_pos

    def get(self):
        return self._from_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += f'<-({self._from_partition_and_offset})'
        return res

    def __repr__(self) -> str:
        return self.__str__()

class PartitionOutputVal(object):
    def __init__(self) -> None:
        # records which partition_id and which offset every output goes to
        self._to_partition_and_offset: List[ValPosition] = []

    def add(self, partition_id, offset):
        val_pos = ValPosition(partition_id, offset)
        self._to_partition_and_offset.append(val_pos)

    def get(self):
        return self._to_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += '->('
        for val_pos in self._to_partition_and_offset:
            res += f'{val_pos},'
        res += ')'
        return res

    def __repr__(self) -> str:
        return self.__str__()

class Partition(object):
    def __init__(self) -> None:
        self._input_vals: List[PartitionInputVal] = []
        self._output_vals: List[PartitionOutputVal] = []

    def add_input_val(self, input_val: PartitionInputVal):
        self._input_vals.append(input_val)

    def add_output_val(self, output_val: PartitionOutputVal):
        self._output_vals.append(output_val)

    def get_input_vals(self):
        return self._input_vals

    def get_output_vals(self):
        return self._output_vals

    def __str__(self) -> str:
        res = ''
        res += f' input:\n'
        res += f' length:{len(self._input_vals)}\n'
        for i, input_val in enumerate(self._input_vals):
            res += f' offset={i}:{input_val}\n'

        res += f' output:\n'
        res += f' length:{len(self._output_vals)}\n'
        for i, output_val in enumerate(self._output_vals):
            res += f' offset={i}:{output_val}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()

# This class is a middleware between the partition splitter
# and the Pipeline Scheduler. It records the graph info about
# partition input/output and provides it to the scheduler.
# There are three kinds of partition in the Pipeline Middleware design,
# which together represent the whole process of a model execution: input-fwd-output
# 1. input_partition: records the input of a model.
# 2. mid_partition: records the split forward execution of a model.
# 3. output_partition: records the output of a model.
# attributes:
#   _partitions: includes all partitions
#   _input_partition_id: the key that identifies the input_partition
#   _output_partition_id: the key that identifies the output_partition
class Topo(object):
    def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
        self._partitions: Dict[int, Partition] = {}
        self._input_partition_id = input_partition_id
        self._output_partition_id = output_partition_id

    def set_input_partition(self, partition_id: int):
        self._input_partition_id = partition_id

    def set_output_partition(self, partition_id: int):
        self._output_partition_id = partition_id

    def set_partitions(self, partition_id: int, partition: Partition):
        self._partitions[partition_id] = partition

    def get_mid_partitions(self):
        res = {}    # {partition_id: Partition}
        for partition_id, partition in self._partitions.items():
            if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
                continue
            res[partition_id] = partition
        return res

    def get_input_partition(self):
        if self._input_partition_id is not None:
            return self._partitions[self._input_partition_id]
        return None

    def get_output_partition(self):
        if self._output_partition_id is not None:
            return self._partitions[self._output_partition_id]
        return None

    def __str__(self) -> str:
        res = ''
        if len(self._partitions) == 0:
            return 'Empty Topo Graph.'

        input_part = self.get_input_partition()
        if input_part is not None:
            res += '{\n'
            res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
            res += '}\n'

        mid_parts = self.get_mid_partitions()
        for i, (partition_id, part) in enumerate(mid_parts.items()):
            res += '{\n'
            res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
            res += '}\n'

        output_part = self.get_output_partition()
        if output_part is not None:
            res += '{\n'
            res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
            res += '}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()
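To make the data structure concrete, here is a small hand-built example (an editor's sketch, not part of the diff): one model input feeding a single submodule whose result is returned. Partition ids follow the convention used by the fx adaptor above: 0 = input partition, 1 = output partition, submod_0 -> 2.

from colossalai.pipeline.middleware import Topo, Partition, PartitionInputVal, PartitionOutputVal

# input partition: its single value is consumed by submod_0 (id 2) as arg 0
input_part = Partition()
inp_out = PartitionOutputVal()
inp_out.add(partition_id=2, offset=0)
input_part.add_output_val(inp_out)

# submod_0: takes arg 0 from the input partition, sends its output to the model output
mid_part = Partition()
mid_part.add_input_val(PartitionInputVal(partition_id=0, offset=0))
mid_out = PartitionOutputVal()
mid_out.add(partition_id=1, offset=0)
mid_part.add_output_val(mid_out)

# output partition: receives its only value from submod_0
output_part = Partition()
output_part.add_input_val(PartitionInputVal(partition_id=2, offset=0))

topo = Topo()
topo.set_partitions(partition_id=0, partition=input_part)
topo.set_partitions(partition_id=2, partition=mid_part)
topo.set_partitions(partition_id=1, partition=output_part)
topo.set_input_partition(partition_id=0)
topo.set_output_partition(partition_id=1)
print(topo)    # prints the Input/Sub/Output partition summary defined by Topo.__str__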
@@ -1,85 +0,0 @@
import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(**kwargs)

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    return top_module, split_submodules[0]._DAG

def check_input(input, input_node, top_module):
    for user in input_node.users.keys():
        partition_name = user.name
        assert partition_name in input['output']

def check_submod(submod_partition, node, top_module):
    for arg in node.args:
        input_part_name = None
        if arg.op == 'placeholder':
            input_part_name = 'MODEL_INPUT'
        elif not arg.name.startswith('getitem'):
            input_part_name = arg.name
        else:
            input_part_name = arg.args[0].name
        assert input_part_name in submod_partition['input']

    for user in node.users:
        output_part_names = []
        if user.op == 'output':
            output_part_names.append('MODEL_OUTPUT')
        elif not user.name.startswith('getitem'):
            output_part_names.append(user.name)
        else:
            for n in user.users:
                if n.op == 'output':
                    output_part_names.append('MODEL_OUTPUT')
                else:
                    output_part_names.append(n.name)

        for output_part_name in output_part_names:
            assert output_part_name in submod_partition['output']

def check_DAG(top_module, DAG):
    assert 'input_partition' in DAG
    input_partition = DAG['input_partition']

    for node in top_module.graph.nodes:
        # check input
        if node.op == 'placeholder':
            assert node.name in input_partition
            input = input_partition[node.name]
            check_input(input, node, top_module)
        elif node.op == 'call_module':
            assert node.name in DAG
            submod_partition = DAG[node.name]
            check_submod(submod_partition, node, top_module)
@@ -1,31 +0,0 @@
import pytest
import torch
import transformers
from dag_utils import split_model_and_get_DAG, check_DAG

BATCH_SIZE = 1
SEQ_LENGHT = 16

@pytest.mark.skip('balance split v2 is not ready')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        #transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        top_mod, DAG = split_model_and_get_DAG(model, data_gen)
        check_DAG(top_mod, DAG)

if __name__ == '__main__':
    test_opt()
@@ -0,0 +1,43 @@
import pytest
import torch
import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP

BATCH_SIZE = 1
SEQ_LENGHT = 16

def test_opt():
    MODEL_LIST = [
        MLP,
        transformers.OPTModel,
    ]

    CONFIGS = [
        {'dim': 10, 'layers': 12},
        transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
    ]

    def data_gen_MLP():
        x = torch.zeros((16, 10))
        kwargs = dict(x=x)
        return kwargs

    def data_gen_OPT():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    DATAGEN = [
        data_gen_MLP,
        data_gen_OPT,
    ]

    for i, model_cls in enumerate(MODEL_LIST):
        model = model_cls(config=CONFIGS[i])
        top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i])
        # print(f'{top_mod=}\n----\n{topo=}')
        check_topo(top_mod, topo)

if __name__ == '__main__':
    test_opt()
@@ -0,0 +1,92 @@
import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.middleware.adaptor import get_fx_topology
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

class MLP(torch.nn.Module):
    def __init__(self, config={}):
        super().__init__()
        dim = config['dim']
        layers = config['layers']
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(torch.nn.Linear(dim, dim, bias=False))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def split_model_and_get_DAG(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_topo', topo)

    return top_module, split_submodules[0]._topo

def check_input(top_module, input_partition: Partition):
    partition_output = input_partition.get_output_vals()
    arg_pos = 0
    for node in top_module.graph.nodes:
        if node.op == 'placeholder':
            cur_checkee = partition_output[arg_pos]
            to_partition_and_offset = cur_checkee.get()
            assert len(to_partition_and_offset) == len(node.users.keys())
            arg_pos += 1

    assert arg_pos == len(partition_output)

def check_submod(top_module, part_id, mid_partition: Partition):
    partition_input = mid_partition.get_input_vals()
    partition_output = mid_partition.get_output_vals()

    cnt = 1
    cur_node = None
    for node in top_module.graph.nodes:
        if node.name.startswith('submod'):
            cnt += 1
        if cnt == part_id:
            cur_node = node
            break

    assert len(partition_input) == len(cur_node.args)
    assert len(partition_output) == len(cur_node.users)

def check_topo(top_module, topo: Topo):
    input_partition = topo.get_input_partition()
    mid_partitions = topo.get_mid_partitions()

    check_input(top_module, input_partition)
    for part_id, submod in mid_partitions.items():
        check_submod(top_module, part_id, submod)
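A quick way to exercise the utilities above outside pytest (an editor's sketch; it mirrors the MLP case in test_topo.py):

if __name__ == '__main__':
    # Split the toy MLP defined above into two stages, validate its topology, and print it.
    model = MLP(config={'dim': 10, 'layers': 12})
    top_mod, topo = split_model_and_get_DAG(model, lambda: dict(x=torch.zeros((16, 10))))
    check_topo(top_mod, topo)
    print(topo)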