From d734529a390087f1366b7573410eca5775735b14 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 21 Dec 2022 15:00:24 +0800 Subject: [PATCH] move flow tracer --- chunk_codegen.py | 413 ++++++++++++++++++++++++----------------------- 1 file changed, 207 insertions(+), 206 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 77c28fd32..2c1c09ae5 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -64,212 +64,6 @@ def _is_non_compute_node_except_placeholder_output(node): return False -class FlowTracer(object): - def __init__(self, gm) -> None: - self.gm = gm - self.node_list = list(gm.graph.nodes) - self.flow_trace = {} - - def _add_trace(self, name): - self.flow_trace[name] = [] - - def _add_node(self, trace_name, node): - self.flow_trace[trace_name].append( - {"node": node, "inside_depend": [], "outside_depend": []} - ) - - def _add_inside_depend(self, flow_name, node, inside_depend_node): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["inside_depend"].append(inside_depend_node) - return - raise RuntimeError("node not found") - - def _add_outside_depend( - self, flow_name, node, outside_depend_node, outside_depend_trace - ): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["outside_depend"].append({outside_depend_trace: outside_depend_node}) - return - raise RuntimeError("node not found") - - def _init_trace(self): - for i in self.node_list: - if i.op == "placeholder": - self._add_trace(i.name) - self._add_node(i.name, i) - - def _find_flow_for_node(self, node): - if type(self.node_list[0]) != type(node): - return None - if _is_non_compute_node_except_placeholder(node): - return None - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name - if any(i in node.name for i in ["ones_like"]): - self._add_trace(node.name) - self._add_node(node.name, node) - return node.name - raise RuntimeError("node not found") - - def _find_first_valid_flow(self, flow): - for i in flow: - 
if i is not None: - return i - raise RuntimeError("invalid flow") - - def find_node_flow(self, node): - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name, i - raise RuntimeError("invalid node") - - def _get_flow_mix_node(self, node): - if _is_non_compute_node(node): - return None - _, node_trace = self.find_node_flow(node) - if len(node_trace["outside_depend"]) == 0: - return None - elif len(node_trace["outside_depend"]) > 1: - raise NotImplementedError - vars = list(node_trace["outside_depend"][0].values())[0] - return vars - - def _get_same_flow_node(self, node_list, node): - name, _ = self.find_node_flow(node) - result = [] - for i in self.flow_trace[name]: - if i["node"] in node_list: - result.append(i["node"]) - return result - - def trace_flow(self): - # init trace - self._init_trace() - - for node in self.node_list: - # skip if non compute node - if all( - type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg) - for arg in node.args - ) or _is_non_compute_node(node): - continue - - node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] - - node_domin_flow = self._find_first_valid_flow(node_input_flows) - self._add_node(node_domin_flow, node) - for node_input_flow, arg in zip(node_input_flows, node.args): - if node_input_flow is None: - continue - elif node_input_flow == node_domin_flow: - self._add_inside_depend(node_domin_flow, node, arg) - else: - self._add_outside_depend( - node_domin_flow, node, arg, node_input_flow - ) - return self.flow_trace - - def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer): - inputs, outputs = _find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": start_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "args": {}, - } - flow_block = False - - # TODO don't allow multi 
outputs now - if len(outputs) > 1: - flow_block = True - return flow_block, chunk_info - - for idx in range(start_idx, end_idx + 1): - node = self.node_list[idx] - mix_flow_node = self._get_flow_mix_node(node) - if mix_flow_node is None: - continue - - # if there is a flow mix, op must be in [mul, add, matmul] - # element-wise op requires dim to be equal in every dim - if any(n in node.name for n in ["mul", "add"]): - for i in node.args: - if type(i) == type(mix_flow_node) and i != mix_flow_node: - main_flow_var = i - # if mix flow is a broadcast in chunk dim, - # TODO: need to move that flow out of the chunk - mix_flow_node_dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, node - ) - if mix_flow_node_dim is None: - flow_block = True - break - if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - flow_block = False - for i in self._get_same_flow_node( - chunk_info["inputs"], mix_flow_node - ): - chunk_info["inputs"].remove(i) - # else, we need to chunk mix var as well - else: - # TODO chunk another value - flow_block = True - break - else: - raise NotImplementedError("%s not implemented" % node.name) - - if flow_block: - flow_block = True - return flow_block, chunk_info - - inputs_dim = [] - remove_inputs = [] - for input_node in chunk_info["inputs"]: - input_dict = {} - for user in input_node.users.keys(): - if _is_non_compute_node(user): - continue - user_idx = _find_idx_by_name(user.name, self.node_list) - dim = None - if start_dim <= user_idx < end_idx: - dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, input_node - ) - elif user_idx == end_idx: - dim = end_dim - # n has relation with chunk dim - if dim is not None and _get_node_shape(user)[dim] != 1: - input_dict[user_idx] = dim - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - chunk_info["inputs_dim"] = inputs_dim - for i in remove_inputs: - if i in chunk_info["inputs"]: - 
chunk_info["inputs"].remove(i) - - # we need to log input nodes to avoid deleteing them in the loop - non_chunk_inputs = _find_chunk_all_input_nodes( - self.node_list[start_idx : end_idx + 1] - ) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - - return flow_block, chunk_info - - class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -932,6 +726,213 @@ class IndexTracer(object): return True + +class FlowTracer(object): + def __init__(self, gm) -> None: + self.gm = gm + self.node_list = list(gm.graph.nodes) + self.flow_trace = {} + + def _add_trace(self, name): + self.flow_trace[name] = [] + + def _add_node(self, trace_name, node): + self.flow_trace[trace_name].append( + {"node": node, "inside_depend": [], "outside_depend": []} + ) + + def _add_inside_depend(self, flow_name, node, inside_depend_node): + for i in self.flow_trace[flow_name]: + if i["node"] == node: + i["inside_depend"].append(inside_depend_node) + return + raise RuntimeError("node not found") + + def _add_outside_depend( + self, flow_name, node, outside_depend_node, outside_depend_trace + ): + for i in self.flow_trace[flow_name]: + if i["node"] == node: + i["outside_depend"].append({outside_depend_trace: outside_depend_node}) + return + raise RuntimeError("node not found") + + def _init_trace(self): + for i in self.node_list: + if i.op == "placeholder": + self._add_trace(i.name) + self._add_node(i.name, i) + + def _find_flow_for_node(self, node): + if type(self.node_list[0]) != type(node): + return None + if _is_non_compute_node_except_placeholder(node): + return None + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i["node"]: + return name + if any(i in node.name for i in ["ones_like"]): + self._add_trace(node.name) + self._add_node(node.name, node) + return node.name + raise RuntimeError("node not found") + + def _find_first_valid_flow(self, flow): + for i in flow: + if i is not None: + 
return i + raise RuntimeError("invalid flow") + + def find_node_flow(self, node): + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i["node"]: + return name, i + raise RuntimeError("invalid node") + + def _get_flow_mix_node(self, node): + if _is_non_compute_node(node): + return None + _, node_trace = self.find_node_flow(node) + if len(node_trace["outside_depend"]) == 0: + return None + elif len(node_trace["outside_depend"]) > 1: + raise NotImplementedError + vars = list(node_trace["outside_depend"][0].values())[0] + return vars + + def _get_same_flow_node(self, node_list, node): + name, _ = self.find_node_flow(node) + result = [] + for i in self.flow_trace[name]: + if i["node"] in node_list: + result.append(i["node"]) + return result + + def trace_flow(self): + # init trace + self._init_trace() + + for node in self.node_list: + # skip if non compute node + if all( + type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg) + for arg in node.args + ) or _is_non_compute_node(node): + continue + + node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] + + node_domin_flow = self._find_first_valid_flow(node_input_flows) + self._add_node(node_domin_flow, node) + for node_input_flow, arg in zip(node_input_flows, node.args): + if node_input_flow is None: + continue + elif node_input_flow == node_domin_flow: + self._add_inside_depend(node_domin_flow, node, arg) + else: + self._add_outside_depend( + node_domin_flow, node, arg, node_input_flow + ) + return self.flow_trace + + def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + inputs, outputs = _find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": start_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "args": {}, + } + flow_block = False + + # TODO don't allow multi outputs 
now + if len(outputs) > 1: + flow_block = True + return flow_block, chunk_info + + for idx in range(start_idx, end_idx + 1): + node = self.node_list[idx] + mix_flow_node = self._get_flow_mix_node(node) + if mix_flow_node is None: + continue + + # if there is a flow mix, op must be in [mul, add, matmul] + # element-wise op requires dim to be equal in every dim + if any(n in node.name for n in ["mul", "add"]): + for i in node.args: + if type(i) == type(mix_flow_node) and i != mix_flow_node: + main_flow_var = i + # if mix flow is a broadcast in chunk dim, + # TODO: need to move that flow out of the chunk + mix_flow_node_dim = index_tracer.get_node_chunk_dim( + self.node_list[end_idx], end_dim, node + ) + if mix_flow_node_dim is None: + flow_block = True + break + if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: + flow_block = False + for i in self._get_same_flow_node( + chunk_info["inputs"], mix_flow_node + ): + chunk_info["inputs"].remove(i) + # else, we need to chunk mix var as well + else: + # TODO chunk another value + flow_block = True + break + else: + raise NotImplementedError("%s not implemented" % node.name) + + if flow_block: + flow_block = True + return flow_block, chunk_info + + inputs_dim = [] + remove_inputs = [] + for input_node in chunk_info["inputs"]: + input_dict = {} + for user in input_node.users.keys(): + if _is_non_compute_node(user): + continue + user_idx = _find_idx_by_name(user.name, self.node_list) + dim = None + if start_dim <= user_idx < end_idx: + dim = index_tracer.get_node_chunk_dim( + self.node_list[end_idx], end_dim, input_node + ) + elif user_idx == end_idx: + dim = end_dim + # n has relation with chunk dim + if dim is not None and _get_node_shape(user)[dim] != 1: + input_dict[user_idx] = dim + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + chunk_info["inputs_dim"] = inputs_dim + for i in remove_inputs: + if i in chunk_info["inputs"]: + chunk_info["inputs"].remove(i) + + 
# we need to log input nodes to avoid deleting them in the loop + non_chunk_inputs = _find_chunk_all_input_nodes( + self.node_list[start_idx : end_idx + 1] + ) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + + return flow_block, chunk_info + + class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: self.index_tracer = index_tracer