remove flow tracer

pull/2364/head
oahzxl 2022-12-23 15:34:41 +08:00
parent 4d89525fc2
commit 4f5e105af3
1 changed file with 27 additions and 144 deletions

View File

@ -67,7 +67,7 @@ def _is_non_compute_node_except_placeholder_output(node):
class IndexTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
self.nodes_list = list(gm.graph.nodes)
self.node_list = list(gm.graph.nodes)
self.idx_trace_list = self._init_idx_trace_list()
self.idx_trace_equal = []
self.idx_view_list = []
@ -75,7 +75,7 @@ class IndexTracer(object):
def _init_idx_trace_list(self):
idx_trace_list = []
for n in self.nodes_list:
for n in self.node_list:
if _get_node_shape(n) != None:
cur_trace = {
"idx": [None for _ in range(len(_get_node_shape(n)))],
@ -136,7 +136,7 @@ class IndexTracer(object):
node_from_trace = self._find_trace_from_node(node_from)
node_to_dim = self._transform_index(node_to, node_to_dim)
node_to_trace = self._find_trace_from_node(node_to)
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
if init:
node_to_trace["source"][node_to_dim] = {}
# add dim to cur new source
@ -210,7 +210,7 @@ class IndexTracer(object):
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_idx = _find_idx_by_name(node.name, self.node_list)
node_dict = self.idx_trace_list[node_idx]
return node_dict
@ -224,7 +224,7 @@ class IndexTracer(object):
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_idx = _find_idx_by_name(node.name, self.node_list)
node_dict = self.idx_trace_list[node_idx]
return node_dict["source"]
@ -237,7 +237,7 @@ class IndexTracer(object):
Returns:
idx (list): idx of the node
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_idx = _find_idx_by_name(node.name, self.node_list)
return self.idx_trace_list[node_idx]["idx"]
def _find_compute_trace_from_node(self, node):
@ -249,7 +249,7 @@ class IndexTracer(object):
Returns:
compute (list): computed idx of the node.
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_idx = _find_idx_by_name(node.name, self.node_list)
return self.idx_trace_list[node_idx]["compute"]
def _assign_index_as_input(self, node, node_idx, input_node=None):
@ -262,7 +262,7 @@ class IndexTracer(object):
"""
if input_node == None:
input_node = node.args[0]
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"]
new_idx_trace = copy.deepcopy(input_node_idx_trace)
@ -591,7 +591,7 @@ class IndexTracer(object):
]
def trace_index(self):
for idx, node in enumerate(self.nodes_list):
for idx, node in enumerate(self.node_list):
if node.op == "placeholder":
self._assign_all_index(node, idx)
elif node.op == "call_method":
@ -655,7 +655,7 @@ class IndexTracer(object):
Returns:
bool: True if check pass
"""
start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
start_node_idx = _find_idx_by_name(start_node.name, self.node_list)
end_node_trace = self._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(
@ -690,14 +690,14 @@ class IndexTracer(object):
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim]
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
node_to_idx = _find_idx_by_name(node_to.name, self.node_list)
for k, v in dim_source.items():
if k == node_to_idx:
return v
return None
def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
node_trace_source = self._find_source_trace_from_node(node)
for node_dim in range(len(_get_node_shape(node))):
if (
@ -711,11 +711,11 @@ class IndexTracer(object):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k])
inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k])
if inherit_dim:
input_dim_after_node[k] = inherit_dim
for node in self.nodes_list[
for node in self.node_list[
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
]:
if _is_non_compute_node_except_placeholder(node):
@ -746,124 +746,11 @@ class IndexTracer(object):
else:
return True
class FlowTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
self.node_list = list(gm.graph.nodes)
self.flow_trace = {}
def _add_trace(self, name):
self.flow_trace[name] = []
def _add_node(self, trace_name, node):
self.flow_trace[trace_name].append(
{"node": node, "inside_depend": [], "outside_depend": []}
)
def _add_inside_depend(self, flow_name, node, inside_depend_node):
for i in self.flow_trace[flow_name]:
if i["node"] == node:
i["inside_depend"].append(inside_depend_node)
return
raise RuntimeError("node not found")
def _add_outside_depend(
self, flow_name, node, outside_depend_node, outside_depend_trace
):
for i in self.flow_trace[flow_name]:
if i["node"] == node:
i["outside_depend"].append({outside_depend_trace: outside_depend_node})
return
raise RuntimeError("node not found")
def _init_trace(self):
for i in self.node_list:
if i.op == "placeholder":
self._add_trace(i.name)
self._add_node(i.name, i)
def _find_flow_for_node(self, node):
    """Return the name of the flow that contains ``node``.

    Returns ``None`` when ``node`` has a different type from the graph
    nodes (compared against ``self.node_list[0]``) or when
    ``_is_non_compute_node_except_placeholder`` accepts it. If no flow
    holds the node and its name contains ``"ones_like"``, a new flow named
    after the node is created and returned.

    Raises:
        RuntimeError: the node is in no flow and is not a ``ones_like`` node.
    """
    graph_node_type = type(self.node_list[0])
    if graph_node_type != type(node):
        return None
    if _is_non_compute_node_except_placeholder(node):
        return None

    for flow_name, records in self.flow_trace.items():
        if any(node == rec["node"] for rec in records):
            return flow_name

    # ones_like nodes start a flow of their own
    node_name = node.name
    if "ones_like" in node_name:
        self._add_trace(node_name)
        self._add_node(node_name, node)
        return node_name
    raise RuntimeError("node not found")
def _find_first_valid_flow(self, flow):
for i in flow:
if i is not None:
return i
raise RuntimeError("invalid flow")
def find_node_flow(self, node):
for name, trace in self.flow_trace.items():
for i in trace:
if node == i["node"]:
return name, i
raise RuntimeError("invalid node")
def _get_flow_mix_node(self, node):
    """Return the node from another flow that ``node`` depends on, if any.

    Returns ``None`` when ``_is_non_compute_node`` accepts ``node`` or when
    its record has no outside dependency. With exactly one outside
    dependency, the dependent node stored in that entry is returned.

    Raises:
        NotImplementedError: the node has more than one outside dependency.
        RuntimeError: propagated from ``find_node_flow`` when the node is
            not recorded in any flow.
    """
    if _is_non_compute_node(node):
        return None
    _, record = self.find_node_flow(node)
    outside = record["outside_depend"]
    dep_count = len(outside)
    if dep_count == 0:
        return None
    if dep_count > 1:
        raise NotImplementedError
    # each entry is a one-item dict {flow_name: node}; take its value
    mix_node = list(outside[0].values())[0]
    return mix_node
def _get_same_flow_node(self, node_list, node):
name, _ = self.find_node_flow(node)
result = []
for i in self.flow_trace[name]:
if i["node"] in node_list:
result.append(i["node"])
return result
def trace_flow(self):
    """Build the flow table for every node in ``self.node_list``.

    Placeholders seed the flows. Each remaining node that is not skipped
    joins the flow of its first argument that has one (the "dominant"
    flow). Every argument with a flow is then recorded on the node as an
    inside dependency (same flow) or an outside dependency (other flow).

    Returns:
        dict: ``self.flow_trace``, mapping flow name to its node records.
    """
    # init trace
    self._init_trace()

    for cur in self.node_list:
        cur_args = cur.args
        # skip if none of the args is a compute node of the same type
        if all(
            type(a) != type(cur) or _is_non_compute_node_except_placeholder(a)
            for a in cur_args
        ):
            continue
        # skip if the node itself is a non compute node
        if _is_non_compute_node(cur):
            continue

        arg_flows = [self._find_flow_for_node(a) for a in cur_args]
        dominant = self._find_first_valid_flow(arg_flows)
        self._add_node(dominant, cur)

        for arg_flow, a in zip(arg_flows, cur_args):
            if arg_flow is None:
                continue
            if arg_flow == dominant:
                self._add_inside_depend(dominant, cur, a)
            else:
                self._add_outside_depend(dominant, cur, a, arg_flow)

    return self.flow_trace
def _assgin_single_node_flow(
self,
arg_node,
start_idx,
end_idx,
inputs,
index_tracer,
cur_node_dim,
cur_node_compute,
cur_node_source,
@ -871,7 +758,7 @@ class FlowTracer(object):
all_node_info,
next_node_list,
):
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
arg_idx = _find_idx_by_name(arg_node.name, self.node_list)
# arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx):
return True
@ -911,7 +798,7 @@ class FlowTracer(object):
return True
def flow_search(
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
self, start_idx, start_dim, end_idx, end_dim
):
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
@ -920,7 +807,7 @@ class FlowTracer(object):
if len(outputs) > 1:
return None
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
cur_node_list = [self.node_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@ -930,12 +817,12 @@ class FlowTracer(object):
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
if cur_node_chunk_dim:
cur_node_compute = index_tracer._find_compute_trace_from_node(
cur_node_compute = self._find_compute_trace_from_node(
cur_node
)
cur_node_source = index_tracer._find_source_trace_from_node(
cur_node_source = self._find_source_trace_from_node(
cur_node
)
else:
@ -953,8 +840,6 @@ class FlowTracer(object):
arg,
start_idx,
end_idx,
inputs,
index_tracer,
cur_node_chunk_dim,
cur_node_compute,
cur_node_source,
@ -970,7 +855,7 @@ class FlowTracer(object):
for arg in arg_list:
if not (
start_idx
<= _find_idx_by_name(arg.name, index_tracer.nodes_list)
<= _find_idx_by_name(arg.name, self.node_list)
< end_idx
):
continue
@ -1029,7 +914,7 @@ class FlowTracer(object):
if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list),
key=lambda x: _find_idx_by_name(x.name, self.node_list),
reverse=True,
) # from last node to first node
prepose_nodes = []
@ -1081,7 +966,7 @@ class FlowTracer(object):
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)
key=lambda x: _find_idx_by_name(x.name, self.node_list)
)
chunk_info["args"]["prepose_nodes"] = prepose_nodes
@ -1226,9 +1111,9 @@ class MemoryEstimator(object):
for k, v in input_node_dim.items():
# TODO: inherit dim should be list too, int now
inherit_dim = self.index_tracer._find_inherit_dim(
input_node, v, self.index_tracer.nodes_list[k]
input_node, v, self.index_tracer.node_list[k]
)
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
if k == _find_idx_by_name(node.name, self.index_tracer.node_list):
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
return chunk_ratio
for dim, source in enumerate(node_source):
@ -1412,8 +1297,6 @@ class ChunkRegionSearch(object):
self.node_list = list(gm.graph.nodes)
self.index_tracer = IndexTracer(gm)
self.index_tracer.trace_index()
self.flow_tracer = FlowTracer(gm)
self.flow_tracer.trace_flow()
self.memory_estimator = MemoryEstimator(self.index_tracer)
def _find_peak_node(self, mem_peak):
@ -1517,8 +1400,8 @@ class ChunkRegionSearch(object):
):
continue
# flow search
chunk_info = self.flow_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim, self.index_tracer
chunk_info = self.index_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim
)
if chunk_info is None:
continue