|
|
|
@ -64,212 +64,6 @@ def _is_non_compute_node_except_placeholder_output(node):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlowTracer(object):
    """Group graph nodes into data "flows" and analyse candidate chunk regions.

    ``flow_trace`` maps a flow name to a list of entries.  Each entry is a
    dict with three keys:

    * ``"node"``: the node itself,
    * ``"inside_depend"``: argument nodes that belong to the same flow,
    * ``"outside_depend"``: a list of ``{other_flow_name: arg_node}`` dicts
      for argument nodes that belong to a different flow.

    A flow is opened for every ``placeholder`` node, and lazily for any node
    whose name contains ``"ones_like"``.
    """

    def __init__(self, gm) -> None:
        """Keep ``gm`` and snapshot its graph nodes; tracing is not run here.

        Args:
            gm: object exposing ``gm.graph.nodes``
                (presumably a ``torch.fx.GraphModule`` -- TODO confirm).
        """
        self.gm = gm
        self.node_list = list(gm.graph.nodes)
        # flow name -> list of {"node", "inside_depend", "outside_depend"}
        self.flow_trace = {}

    def _add_trace(self, name):
        """Create (or reset) the flow called ``name`` as an empty list."""
        self.flow_trace[name] = []

    def _add_node(self, trace_name, node):
        """Append ``node`` to flow ``trace_name`` with empty dependency lists."""
        self.flow_trace[trace_name].append(
            {"node": node, "inside_depend": [], "outside_depend": []}
        )

    def _add_inside_depend(self, flow_name, node, inside_depend_node):
        """Record that ``node`` depends on ``inside_depend_node`` of its own flow.

        Raises:
            RuntimeError: ``node`` is not registered in ``flow_name``.
        """
        for entry in self.flow_trace[flow_name]:
            if entry["node"] == node:
                entry["inside_depend"].append(inside_depend_node)
                return
        raise RuntimeError("node not found")

    def _add_outside_depend(
        self, flow_name, node, outside_depend_node, outside_depend_trace
    ):
        """Record that ``node`` depends on a node that lives in another flow.

        The dependency is stored as ``{outside_depend_trace: outside_depend_node}``.

        Raises:
            RuntimeError: ``node`` is not registered in ``flow_name``.
        """
        for entry in self.flow_trace[flow_name]:
            if entry["node"] == node:
                entry["outside_depend"].append(
                    {outside_depend_trace: outside_depend_node}
                )
                return
        raise RuntimeError("node not found")

    def _init_trace(self):
        """Open one flow per ``placeholder`` node, seeded with that node."""
        for node in self.node_list:
            if node.op == "placeholder":
                self._add_trace(node.name)
                self._add_node(node.name, node)

    def _find_flow_for_node(self, node):
        """Return the name of the flow ``node`` belongs to, or ``None``.

        ``None`` is returned when ``node`` is not a graph node (its type does
        not match the type of the first entry of ``node_list``) or when
        ``_is_non_compute_node_except_placeholder`` accepts it.  An
        unregistered node whose name contains ``"ones_like"`` starts a new
        flow named after itself.

        Raises:
            RuntimeError: the node is unregistered and cannot start a flow.
        """
        if not isinstance(node, type(self.node_list[0])):
            return None
        if _is_non_compute_node_except_placeholder(node):
            return None
        for name, trace in self.flow_trace.items():
            for entry in trace:
                if node == entry["node"]:
                    return name
        if "ones_like" in node.name:
            self._add_trace(node.name)
            self._add_node(node.name, node)
            return node.name
        raise RuntimeError("node not found")

    def _find_first_valid_flow(self, flow):
        """Return the first element of ``flow`` that is not ``None``.

        Raises:
            RuntimeError: every element is ``None`` (or ``flow`` is empty).
        """
        for candidate in flow:
            if candidate is not None:
                return candidate
        raise RuntimeError("invalid flow")

    def find_node_flow(self, node):
        """Return ``(flow_name, entry)`` for a registered ``node``.

        Raises:
            RuntimeError: ``node`` is not part of any flow.
        """
        for name, trace in self.flow_trace.items():
            for entry in trace:
                if node == entry["node"]:
                    return name, entry
        raise RuntimeError("invalid node")

    def _get_flow_mix_node(self, node):
        """Return the single foreign-flow argument of ``node``, if there is one.

        Returns ``None`` for nodes accepted by ``_is_non_compute_node`` and for
        nodes without outside dependencies.

        Raises:
            NotImplementedError: ``node`` has more than one outside dependency.
            RuntimeError: ``node`` is not part of any flow.
        """
        if _is_non_compute_node(node):
            return None
        _, node_trace = self.find_node_flow(node)
        outside = node_trace["outside_depend"]
        if len(outside) == 0:
            return None
        if len(outside) > 1:
            raise NotImplementedError
        # each item is a one-entry dict {flow_name: node}; take the node
        return list(outside[0].values())[0]

    def _get_same_flow_node(self, node_list, node):
        """Return the nodes of ``node``'s flow that also occur in ``node_list``.

        The result is a new list in flow order, so callers may mutate
        ``node_list`` while iterating over it.
        """
        name, _ = self.find_node_flow(node)
        return [
            entry["node"]
            for entry in self.flow_trace[name]
            if entry["node"] in node_list
        ]

    def trace_flow(self):
        """Assign every compute node to a flow and record its dependencies.

        A node joins the flow of its first argument that has one (the
        "dominant" flow).  Arguments from that flow become inside
        dependencies; arguments from any other flow become outside
        dependencies.

        Returns:
            The populated ``flow_trace`` dict.
        """
        # init trace
        self._init_trace()

        for node in self.node_list:
            # skip if non compute node; all() is also True for a node with no
            # args, so such nodes are skipped as well
            if all(
                not isinstance(arg, type(node))
                or _is_non_compute_node_except_placeholder(arg)
                for arg in node.args
            ) or _is_non_compute_node(node):
                continue

            node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]

            node_domin_flow = self._find_first_valid_flow(node_input_flows)
            self._add_node(node_domin_flow, node)
            for node_input_flow, arg in zip(node_input_flows, node.args):
                if node_input_flow is None:
                    continue
                if node_input_flow == node_domin_flow:
                    self._add_inside_depend(node_domin_flow, node, arg)
                else:
                    self._add_outside_depend(
                        node_domin_flow, node, arg, node_input_flow
                    )
        return self.flow_trace

    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer):
        """Check whether ``node_list[start_idx:end_idx + 1]`` can be chunked.

        Args:
            start_idx: index into ``node_list`` of the first node of the region.
            start_dim: stored as the initial ``"inputs_dim"``; replaced by a
                per-input list when the analysis runs to completion.
            end_idx: index into ``node_list`` of the last node (inclusive).
            end_dim: chunk dimension of the last node of the region.
            index_tracer: object providing
                ``get_node_chunk_dim(end_node, end_dim, node)``.

        Returns:
            ``(flow_block, chunk_info)``.  ``flow_block`` is ``True`` when the
            region is rejected; ``chunk_info`` is then only partially filled.

        Raises:
            NotImplementedError: a node that mixes two flows is not a
                ``mul``/``add`` node, or mixes more than two flows.
        """
        inputs, outputs = _find_chunk_compute_input_and_output_nodes(
            self.node_list[start_idx : end_idx + 1]
        )
        chunk_info = {
            "region": (start_idx, end_idx),
            "inputs": inputs,
            "inputs_non_chunk": [],
            "inputs_dim": start_dim,
            "outputs": outputs,
            "outputs_dim": end_dim,
            "args": {},
        }

        # TODO don't allow multi outputs now
        if len(outputs) > 1:
            return True, chunk_info

        for idx in range(start_idx, end_idx + 1):
            node = self.node_list[idx]
            mix_flow_node = self._get_flow_mix_node(node)
            if mix_flow_node is None:
                continue

            # if there is a flow mix, op must be in [mul, add, matmul]
            # element-wise op requires dim to be equal in every dim
            if not any(n in node.name for n in ["mul", "add"]):
                raise NotImplementedError("%s not implemented" % node.name)

            # if mix flow is a broadcast in chunk dim,
            # TODO: need to move that flow out of the chunk
            mix_flow_node_dim = index_tracer.get_node_chunk_dim(
                self.node_list[end_idx], end_dim, node
            )
            if mix_flow_node_dim is None:
                return True, chunk_info
            if _get_node_shape(mix_flow_node)[mix_flow_node_dim] != 1:
                # the mix var would need to be chunked as well
                # TODO chunk another value
                return True, chunk_info
            # broadcast along the chunk dim: inputs of that flow are not chunked
            for same_flow_input in self._get_same_flow_node(
                chunk_info["inputs"], mix_flow_node
            ):
                chunk_info["inputs"].remove(same_flow_input)

        inputs_dim = []
        remove_inputs = []
        for input_node in chunk_info["inputs"]:
            input_dict = {}
            for user in input_node.users.keys():
                if _is_non_compute_node(user):
                    continue
                user_idx = _find_idx_by_name(user.name, self.node_list)
                dim = None
                # region membership is decided by node index, so the lower
                # bound is start_idx (a dimension is not comparable to an index)
                if start_idx <= user_idx < end_idx:
                    dim = index_tracer.get_node_chunk_dim(
                        self.node_list[end_idx], end_dim, input_node
                    )
                elif user_idx == end_idx:
                    dim = end_dim
                # n has relation with chunk dim
                if dim is not None and _get_node_shape(user)[dim] != 1:
                    input_dict[user_idx] = dim
            if len(input_dict) == 0:
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        chunk_info["inputs_dim"] = inputs_dim
        for dropped in remove_inputs:
            if dropped in chunk_info["inputs"]:
                chunk_info["inputs"].remove(dropped)

        # we need to log input nodes to avoid deleting them in the loop
        non_chunk_inputs = _find_chunk_all_input_nodes(
            self.node_list[start_idx : end_idx + 1]
        )
        for candidate in non_chunk_inputs:
            if candidate not in chunk_info["inputs"]:
                chunk_info["inputs_non_chunk"].append(candidate)

        return False, chunk_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IndexTracer(object):
|
|
|
|
|
def __init__(self, gm) -> None:
|
|
|
|
|
self.gm = gm
|
|
|
|
@ -932,6 +726,213 @@ class IndexTracer(object):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlowTracer(object):
    """Group graph nodes into data "flows" and analyse candidate chunk regions.

    ``flow_trace`` maps a flow name to a list of entries.  Each entry is a
    dict with three keys:

    * ``"node"``: the node itself,
    * ``"inside_depend"``: argument nodes that belong to the same flow,
    * ``"outside_depend"``: a list of ``{other_flow_name: arg_node}`` dicts
      for argument nodes that belong to a different flow.

    A flow is opened for every ``placeholder`` node, and lazily for any node
    whose name contains ``"ones_like"``.
    """

    def __init__(self, gm) -> None:
        """Keep ``gm`` and snapshot its graph nodes; tracing is not run here.

        Args:
            gm: object exposing ``gm.graph.nodes``
                (presumably a ``torch.fx.GraphModule`` -- TODO confirm).
        """
        self.gm = gm
        self.node_list = list(gm.graph.nodes)
        # flow name -> list of {"node", "inside_depend", "outside_depend"}
        self.flow_trace = {}

    def _add_trace(self, name):
        """Create (or reset) the flow called ``name`` as an empty list."""
        self.flow_trace[name] = []

    def _add_node(self, trace_name, node):
        """Append ``node`` to flow ``trace_name`` with empty dependency lists."""
        self.flow_trace[trace_name].append(
            {"node": node, "inside_depend": [], "outside_depend": []}
        )

    def _add_inside_depend(self, flow_name, node, inside_depend_node):
        """Record that ``node`` depends on ``inside_depend_node`` of its own flow.

        Raises:
            RuntimeError: ``node`` is not registered in ``flow_name``.
        """
        for entry in self.flow_trace[flow_name]:
            if entry["node"] == node:
                entry["inside_depend"].append(inside_depend_node)
                return
        raise RuntimeError("node not found")

    def _add_outside_depend(
        self, flow_name, node, outside_depend_node, outside_depend_trace
    ):
        """Record that ``node`` depends on a node that lives in another flow.

        The dependency is stored as ``{outside_depend_trace: outside_depend_node}``.

        Raises:
            RuntimeError: ``node`` is not registered in ``flow_name``.
        """
        for entry in self.flow_trace[flow_name]:
            if entry["node"] == node:
                entry["outside_depend"].append(
                    {outside_depend_trace: outside_depend_node}
                )
                return
        raise RuntimeError("node not found")

    def _init_trace(self):
        """Open one flow per ``placeholder`` node, seeded with that node."""
        for node in self.node_list:
            if node.op == "placeholder":
                self._add_trace(node.name)
                self._add_node(node.name, node)

    def _find_flow_for_node(self, node):
        """Return the name of the flow ``node`` belongs to, or ``None``.

        ``None`` is returned when ``node`` is not a graph node (its type does
        not match the type of the first entry of ``node_list``) or when
        ``_is_non_compute_node_except_placeholder`` accepts it.  An
        unregistered node whose name contains ``"ones_like"`` starts a new
        flow named after itself.

        Raises:
            RuntimeError: the node is unregistered and cannot start a flow.
        """
        if not isinstance(node, type(self.node_list[0])):
            return None
        if _is_non_compute_node_except_placeholder(node):
            return None
        for name, trace in self.flow_trace.items():
            for entry in trace:
                if node == entry["node"]:
                    return name
        if "ones_like" in node.name:
            self._add_trace(node.name)
            self._add_node(node.name, node)
            return node.name
        raise RuntimeError("node not found")

    def _find_first_valid_flow(self, flow):
        """Return the first element of ``flow`` that is not ``None``.

        Raises:
            RuntimeError: every element is ``None`` (or ``flow`` is empty).
        """
        for candidate in flow:
            if candidate is not None:
                return candidate
        raise RuntimeError("invalid flow")

    def find_node_flow(self, node):
        """Return ``(flow_name, entry)`` for a registered ``node``.

        Raises:
            RuntimeError: ``node`` is not part of any flow.
        """
        for name, trace in self.flow_trace.items():
            for entry in trace:
                if node == entry["node"]:
                    return name, entry
        raise RuntimeError("invalid node")

    def _get_flow_mix_node(self, node):
        """Return the single foreign-flow argument of ``node``, if there is one.

        Returns ``None`` for nodes accepted by ``_is_non_compute_node`` and for
        nodes without outside dependencies.

        Raises:
            NotImplementedError: ``node`` has more than one outside dependency.
            RuntimeError: ``node`` is not part of any flow.
        """
        if _is_non_compute_node(node):
            return None
        _, node_trace = self.find_node_flow(node)
        outside = node_trace["outside_depend"]
        if len(outside) == 0:
            return None
        if len(outside) > 1:
            raise NotImplementedError
        # each item is a one-entry dict {flow_name: node}; take the node
        return list(outside[0].values())[0]

    def _get_same_flow_node(self, node_list, node):
        """Return the nodes of ``node``'s flow that also occur in ``node_list``.

        The result is a new list in flow order, so callers may mutate
        ``node_list`` while iterating over it.
        """
        name, _ = self.find_node_flow(node)
        return [
            entry["node"]
            for entry in self.flow_trace[name]
            if entry["node"] in node_list
        ]

    def trace_flow(self):
        """Assign every compute node to a flow and record its dependencies.

        A node joins the flow of its first argument that has one (the
        "dominant" flow).  Arguments from that flow become inside
        dependencies; arguments from any other flow become outside
        dependencies.

        Returns:
            The populated ``flow_trace`` dict.
        """
        # init trace
        self._init_trace()

        for node in self.node_list:
            # skip if non compute node; all() is also True for a node with no
            # args, so such nodes are skipped as well
            if all(
                not isinstance(arg, type(node))
                or _is_non_compute_node_except_placeholder(arg)
                for arg in node.args
            ) or _is_non_compute_node(node):
                continue

            node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]

            node_domin_flow = self._find_first_valid_flow(node_input_flows)
            self._add_node(node_domin_flow, node)
            for node_input_flow, arg in zip(node_input_flows, node.args):
                if node_input_flow is None:
                    continue
                if node_input_flow == node_domin_flow:
                    self._add_inside_depend(node_domin_flow, node, arg)
                else:
                    self._add_outside_depend(
                        node_domin_flow, node, arg, node_input_flow
                    )
        return self.flow_trace

    def _detect_flow(
        self, start_idx, start_dim, end_idx, end_dim, index_tracer: "IndexTracer"
    ):
        """Check whether ``node_list[start_idx:end_idx + 1]`` can be chunked.

        Args:
            start_idx: index into ``node_list`` of the first node of the region.
            start_dim: stored as the initial ``"inputs_dim"``; replaced by a
                per-input list when the analysis runs to completion.
            end_idx: index into ``node_list`` of the last node (inclusive).
            end_dim: chunk dimension of the last node of the region.
            index_tracer: tracer queried through
                ``get_node_chunk_dim(end_node, end_dim, node)``.

        Returns:
            ``(flow_block, chunk_info)``.  ``flow_block`` is ``True`` when the
            region is rejected; ``chunk_info`` is then only partially filled.

        Raises:
            NotImplementedError: a node that mixes two flows is not a
                ``mul``/``add`` node, or mixes more than two flows.
        """
        inputs, outputs = _find_chunk_compute_input_and_output_nodes(
            self.node_list[start_idx : end_idx + 1]
        )
        chunk_info = {
            "region": (start_idx, end_idx),
            "inputs": inputs,
            "inputs_non_chunk": [],
            "inputs_dim": start_dim,
            "outputs": outputs,
            "outputs_dim": end_dim,
            "args": {},
        }

        # TODO don't allow multi outputs now
        if len(outputs) > 1:
            return True, chunk_info

        for idx in range(start_idx, end_idx + 1):
            node = self.node_list[idx]
            mix_flow_node = self._get_flow_mix_node(node)
            if mix_flow_node is None:
                continue

            # if there is a flow mix, op must be in [mul, add, matmul]
            # element-wise op requires dim to be equal in every dim
            if not any(n in node.name for n in ["mul", "add"]):
                raise NotImplementedError("%s not implemented" % node.name)

            # if mix flow is a broadcast in chunk dim,
            # TODO: need to move that flow out of the chunk
            mix_flow_node_dim = index_tracer.get_node_chunk_dim(
                self.node_list[end_idx], end_dim, node
            )
            if mix_flow_node_dim is None:
                return True, chunk_info
            if _get_node_shape(mix_flow_node)[mix_flow_node_dim] != 1:
                # the mix var would need to be chunked as well
                # TODO chunk another value
                return True, chunk_info
            # broadcast along the chunk dim: inputs of that flow are not chunked
            for same_flow_input in self._get_same_flow_node(
                chunk_info["inputs"], mix_flow_node
            ):
                chunk_info["inputs"].remove(same_flow_input)

        inputs_dim = []
        remove_inputs = []
        for input_node in chunk_info["inputs"]:
            input_dict = {}
            for user in input_node.users.keys():
                if _is_non_compute_node(user):
                    continue
                user_idx = _find_idx_by_name(user.name, self.node_list)
                dim = None
                # region membership is decided by node index, so the lower
                # bound is start_idx (a dimension is not comparable to an index)
                if start_idx <= user_idx < end_idx:
                    dim = index_tracer.get_node_chunk_dim(
                        self.node_list[end_idx], end_dim, input_node
                    )
                elif user_idx == end_idx:
                    dim = end_dim
                # n has relation with chunk dim
                if dim is not None and _get_node_shape(user)[dim] != 1:
                    input_dict[user_idx] = dim
            if len(input_dict) == 0:
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        chunk_info["inputs_dim"] = inputs_dim
        for dropped in remove_inputs:
            if dropped in chunk_info["inputs"]:
                chunk_info["inputs"].remove(dropped)

        # we need to log input nodes to avoid deleting them in the loop
        non_chunk_inputs = _find_chunk_all_input_nodes(
            self.node_list[start_idx : end_idx + 1]
        )
        for candidate in non_chunk_inputs:
            if candidate not in chunk_info["inputs"]:
                chunk_info["inputs_non_chunk"].append(candidate)

        return False, chunk_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryEstimator(object):
|
|
|
|
|
def __init__(self, index_tracer: IndexTracer) -> None:
|
|
|
|
|
self.index_tracer = index_tracer
|
|
|
|
|