pass outproduct mean

pull/2364/head
oahzxl 2022-12-10 17:29:51 +08:00
parent 979e61db92
commit 929445116a
1 changed files with 211 additions and 104 deletions

View File

@ -16,16 +16,31 @@ def _delete_free_var_from_last_use(user_to_last_uses):
if n.op == 'placeholder':
user_to_last_uses[key].remove(n)
def _get_node_shape(node):
if hasattr(node.meta['tensor_meta'], "shape"):
return node.meta['tensor_meta'].shape
return None
def _is_non_compute_node(node):
if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
any(i in node.name for i in ['getitem', 'getattr']):
return True
return False
def _is_non_compute_node_except_placeholder(node):
if any(i in node.op for i in ['get_attr', 'output']) or \
any(i in node.name for i in ['getitem', 'getattr']):
return True
return False
class FlowTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
self.nodes_list = list(gm.graph.nodes)
self.node_list = list(gm.graph.nodes)
self.flow_trace = {}
def _add_trace(self, name):
@ -49,7 +64,7 @@ class FlowTracer(object):
raise RuntimeError("node not found")
def _init_trace(self):
for i in self.nodes_list:
for i in self.node_list:
if i.op == 'placeholder':
self._add_trace(i.name)
self._add_node(i.name, i)
@ -67,7 +82,7 @@ class FlowTracer(object):
return False
def _find_flow_for_node(self, node):
if type(self.nodes_list[0]) != type(node):
if type(self.node_list[0]) != type(node):
return None
if self._is_non_compute_node_except_placeholder(node):
return None
@ -117,7 +132,7 @@ class FlowTracer(object):
# init trace
self._init_trace()
for node in self.nodes_list:
for node in self.node_list:
# skip if non compute node
if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \
or self._is_non_compute_node(node):
@ -136,6 +151,41 @@ class FlowTracer(object):
self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
return self.flow_trace
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = _find_chunk_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
chunk_info = {'region': (start_idx, end_idx),
'inputs': inputs, 'inputs_dim': start_dim,
'outputs': outputs, 'outputs_dim': end_dim,
'args': {}}
flow_flag = False
for idx in range(start_idx, end_idx + 1):
node = self.node_list[idx]
mix_flow_var = self.get_flow_mix(node)
if mix_flow_var is None:
continue
# if there is a flow mix, op must be in [mul, add, div, matmul]
# element-wise op requires dim to be equal in every dim
if any(n in node.name for n in ['mul', 'add']):
for i in node.args:
if type(i) == type(mix_flow_var) and i != mix_flow_var:
main_flow_var = i
# if mix flow is a broadcast in chunk dim,
# TODO need to move that flow out of the chunk
if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
flow_flag = True
for i in self.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
chunk_info['inputs'].remove(i)
# else, we need to chunk mix var as well
else:
# TODO chunk another value
flow_flag = False
break
else:
raise NotImplementedError("%s not implemented" % node.name)
return flow_flag, chunk_info
class IndexTracer(object):
def __init__(self, gm) -> None:
@ -153,7 +203,7 @@ class IndexTracer(object):
cur_trace = {
'idx': [None for _ in range(len(_get_node_shape(n)))],
'compute': [[] for _ in range(len(_get_node_shape(n)))],
'source': [[] for _ in range(len(_get_node_shape(n)))],
'source': [{} for _ in range(len(_get_node_shape(n)))],
}
else:
cur_trace = {'idx': [], 'compute': [], 'source': []}
@ -178,7 +228,7 @@ class IndexTracer(object):
def _add_dim(self, idx, dim_idx):
self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index())
self.idx_trace_list[idx]['compute'].insert(dim_idx, [])
self.idx_trace_list[idx]['source'].insert(dim_idx, [])
self.idx_trace_list[idx]['source'].insert(dim_idx, {})
def _transform_index(self, node, node_dim):
node_idx = self._find_idx_trace_from_node(node)
@ -192,10 +242,7 @@ class IndexTracer(object):
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim]
node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim])
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
node_to_trace['source'][node_to_dim] = []
node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
def _inherit_all_computation(self, node_from, node_to):
node_from_compute = self._find_compute_trace_from_node(node_from)
@ -205,14 +252,16 @@ class IndexTracer(object):
self._add_source(node_from, i, node_to, i)
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim):
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
node_from_dim = self._transform_index(node_from, node_from_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_dim = self._transform_index(node_to, node_to_dim)
node_to_trace = self._find_trace_from_node(node_to)
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
if init:
node_to_trace['source'][node_to_dim] = {}
node_to_trace['source'][node_to_dim][node_from_idx] = node_from_dim
node_to_trace['source'][node_to_dim].update(node_from_trace['source'][node_from_dim])
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
if exclude == None:
@ -485,11 +534,11 @@ class IndexTracer(object):
source_idx = left_str.index(right_indice)
self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx)
for i in sum_index:
for left_idx, left_str in enumerate(left):
if i in left_str:
self._mark_computation(node, idx, left_str.index(i))
break
# for i in sum_index:
# for left_idx, left_str in enumerate(left):
# if i in left_str:
# self._mark_computation(node, idx, left_str.index(i))
# break
def _assign_softmax_index(self, node, idx):
"""
@ -679,18 +728,56 @@ class IndexTracer(object):
raise NotImplementedError(node.op, "op not implemented yet!")
# self._merge_equal_idx()
def check_index(self, trace_idx, start_idx, end_idx):
for i in range(start_idx, end_idx + 1):
cur_idx = self.idx_trace_list[i]['idx']
cur_compute = self.idx_trace_list[i]['compute']
if trace_idx in cur_compute:
for j in cur_compute[trace_idx]:
if j < start_idx or j > end_idx:
return False
# same_idx = [1 if j == trace_idx else 0 for j in cur_idx]
# if sum(same_idx) > 1:
# return False
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
"""
Check 2 given index: one index should be source of the other
Args:
start_idx(int): start node chunk dim
start_node(node): start node
end_idx(int): end node chunk dim
end_node(node): end node
Returns:
bool: True if check pass
"""
start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
end_node_trace = self._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace['source'][end_dim]
sorted_source = sorted(end_node_trace_source.items(), key=lambda d:d[0], reverse=True)
for node_idx, node_dim in sorted_source:
if node_idx == start_node_idx and node_dim == start_dim:
return True
# it means we meet a node outside the loop, and the node is not input node
if node_idx < start_idx:
return False
return False
def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
"""
Check 2 given index: check they haven't been computed in the source trace.
Args:
start_idx(int): start node chunk dim
start_node(node): start node
end_idx(int): end node chunk dim
end_node(node): end node
Returns:
bool: True if check pass
"""
end_node_trace = self._find_trace_from_node(end_node)
end_node_compute = end_node_trace['compute'][end_dim]
if any(start_idx <= i <= end_idx for i in end_node_compute):
return False
return True
# end_node_trace_source = end_node_trace['source'][end_dim]
# for node_idx, node_dim in end_node_trace_source.items():
# if node_idx < start_node_idx or node_idx > end_node_idx:
# continue
# compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
# if any(start_node_idx <= i <= end_node_idx for i in compute_list):
# return False
# return True
class MemoryEstimator(object):
def __init__(self) -> None:
@ -951,88 +1038,81 @@ class ChunkRegionSearch(object):
return True
return False
def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
chunk_info = {'inputs': inputs, 'outputs': outputs}
flow_flag = False
for idx in range(start_idx, end_idx + 1):
node = self.node_list[idx]
mix_flow_var = self.flow_tracer.get_flow_mix(node)
if mix_flow_var is None:
def _check_duplicate_map(self, chunk_infos):
dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos]
remove_list = []
for idx1, (input_dim1, output_dim1) in enumerate(dim_map):
for idx2, (input_dim2, output_dim2) in enumerate(dim_map):
if idx1 == idx2:
continue
# if there is a flow mix, op must be in [mul, add, div, matmul]
# element-wise op requires dim to be equal in every dim
if any(n in node.name for n in ['mul', 'add']):
for i in node.args:
if type(i) == type(mix_flow_var) and i != mix_flow_var:
main_flow_var = i
# if mix flow is a broadcast in chunk dim,
# TODO need to move that flow out of the chunk
if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
flow_flag = True
for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
chunk_info['inputs'].remove(i)
# else, we need to chunk mix var as well
else:
# TODO chunk another value
flow_flag = False
break
else:
raise NotImplementedError("%s not implemented" % node.name)
return flow_flag, chunk_info
# it means an index create 2 copy of itself
# eg. a = torch.matmul(x, x.transpose(-1, -2))
# TODO currently remove it, deal with this in future
if input_dim1 == input_dim2 and output_dim1 != output_dim2:
remove_list.append(chunk_infos[idx1])
remove_list.append(chunk_infos[idx2])
for i in remove_list:
if i in chunk_infos:
chunk_infos.remove(i)
return chunk_infos
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
before_trace = input_trace[start_idx]
after_trace = output_trace[end_idx]
free_dim = []
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.node_list[end_idx]
chunk_infos = []
for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
if not (before_trace['idx'][i] == after_trace['idx'][i] and
self._is_not_compute(before_trace, (start_idx, end_idx), i) and
self._is_not_compute(after_trace, (start_idx, end_idx), i) and
self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
for end_dim, end_trace_idx in enumerate(end_trace['idx']):
if len(start_traces) > 1:
# TODO implement multi input chunk
continue
if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx):
for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace['idx']):
# must be same trace idx
if start_trace_idx != end_trace_idx:
continue
flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
if flow_flag == None:
# dim size cannot be 1
if _get_node_shape(end_node)[end_dim] == 1 or \
_get_node_shape(start_node)[start_dim] == 1:
continue
# check index source align
if not self.index_tracer.check_index_source(
start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index copmute
if not self.index_tracer.check_index_compute(
start_idx, end_dim, end_node, end_idx):
continue
# detect flow meet
flow_flag, chunk_info = self.flow_tracer._detect_flow(
start_idx, start_dim, end_idx, end_dim)
if flow_flag:
continue
chunk_infos.append(chunk_info)
free_dim.append(i)
return free_dim, chunk_infos
chunk_infos = self._check_duplicate_map(chunk_infos)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
possible_chunk_region = []
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
input_trace = []
for i, n in enumerate(self.node_list):
if len(n.args) > 0 and n.op != 'output':
if isinstance(n.args[0], str):
input_idx = _find_idx_by_name(n.args[1].name, self.node_list)
else:
input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
input_trace.append(output_trace[input_idx])
else:
input_trace.append(None)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg):
cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node):
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if any(op in ['placeholder', 'get_attr', 'output'] for op in
[self.node_list[start_idx].op, self.node_list[end_idx].op]):
continue
if any(any(i in name for i in ['getitem', 'getattr']) for name in
[self.node_list[start_idx].name, self.node_list[end_idx].name]):
if _is_non_compute_node(self.node_list[start_idx]) or \
_is_non_compute_node(self.node_list[end_idx]):
continue
# select free dim
free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
if len(free_dim) > 0:
free_dim = [free_dim[0]]
chunk_info = [chunk_info[0]]
possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
def _search_best_chunk_region(self, possible_chunk_regions):
@ -1044,7 +1124,8 @@ class ChunkRegionSearch(object):
max_region_range = i['region'][1] - i['region'][0]
return best_regions
def _step_search(self, peak_node, active_node):
def _step_search(self, mem_peak, active_node):
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
@ -1062,19 +1143,16 @@ class ChunkRegionSearch(object):
mem_peak = init_mem_peak
while True:
peak_node = self._find_peak_node(mem_peak)
chunk_region = self._step_search(peak_node, active_node)
if chunk_region is None or len(chunk_region['dim']) == 0:
chunk_region = self._step_search(mem_peak, active_node)
if chunk_region is None:
break
chunk_regions.append(chunk_region)
mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(
self.gm, [i['region'][0] for i in chunk_regions],
[i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions))
[i['region'][1] for i in chunk_regions], [i['inputs_dim'] for i in chunk_regions], [1] * len(chunk_regions))
if self._stop_search(init_mem_peak, mem_peak):
break
return chunk_regions
@ -1164,6 +1242,35 @@ def _find_input_and_output_nodes(nodes: List[Node]):
return input_nodes, output_nodes
def _find_chunk_input_and_output_nodes(nodes: List[Node]):
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
output nodes are nodes will use nodes in the list
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if input_node not in nodes and input_node not in input_nodes \
and not _is_non_compute_node_except_placeholder(input_node):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
# TODO it is unsafe to remove non compute node here
for node in nodes:
for output_node in node.users.keys():
if output_node not in nodes and node not in output_nodes \
and not _is_non_compute_node_except_placeholder(input_node):
output_nodes.append(node)
return input_nodes, output_nodes
def _find_idx_by_name(name, nodes_list):
for idx, node in enumerate(nodes_list):
if node.name == name: