diff --git a/chunk_codegen.py b/chunk_codegen.py index e2786d5e2..838f53949 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -67,7 +67,7 @@ def _is_non_compute_node_except_placeholder_output(node): class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm - self.nodes_list = list(gm.graph.nodes) + self.node_list = list(gm.graph.nodes) self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] self.idx_view_list = [] @@ -75,7 +75,7 @@ class IndexTracer(object): def _init_idx_trace_list(self): idx_trace_list = [] - for n in self.nodes_list: + for n in self.node_list: if _get_node_shape(n) != None: cur_trace = { "idx": [None for _ in range(len(_get_node_shape(n)))], @@ -136,7 +136,7 @@ class IndexTracer(object): node_from_trace = self._find_trace_from_node(node_from) node_to_dim = self._transform_index(node_to, node_to_dim) node_to_trace = self._find_trace_from_node(node_to) - node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) + node_from_idx = _find_idx_by_name(node_from.name, self.node_list) if init: node_to_trace["source"][node_to_dim] = {} # add dim to cur new source @@ -210,7 +210,7 @@ class IndexTracer(object): idx (list): idx of the node compute (list): computed idx of the node. """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) node_dict = self.idx_trace_list[node_idx] return node_dict @@ -224,7 +224,7 @@ class IndexTracer(object): idx (list): idx of the node compute (list): computed idx of the node. """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) node_dict = self.idx_trace_list[node_idx] return node_dict["source"] @@ -237,7 +237,7 @@ class IndexTracer(object): Returns: idx (list): idx of the node """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) return self.idx_trace_list[node_idx]["idx"] def _find_compute_trace_from_node(self, node): @@ -249,7 +249,7 @@ class IndexTracer(object): Returns: compute (list): computed idx of the node. """ - node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.node_list) return self.idx_trace_list[node_idx]["compute"] def _assign_index_as_input(self, node, node_idx, input_node=None): @@ -262,7 +262,7 @@ class IndexTracer(object): """ if input_node == None: input_node = node.args[0] - input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) + input_node_idx = _find_idx_by_name(input_node.name, self.node_list) input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] new_idx_trace = copy.deepcopy(input_node_idx_trace) @@ -591,7 +591,7 @@ class IndexTracer(object): ] def trace_index(self): - for idx, node in enumerate(self.nodes_list): + for idx, node in enumerate(self.node_list): if node.op == "placeholder": self._assign_all_index(node, idx) elif node.op == "call_method": @@ -655,7 +655,7 @@ class IndexTracer(object): Returns: bool: True if check pass """ - start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list) + start_node_idx = _find_idx_by_name(start_node.name, self.node_list) end_node_trace = self._find_trace_from_node(end_node) end_node_trace_source = end_node_trace["source"][end_dim] sorted_source = sorted( @@ -690,14 +690,14 @@ class IndexTracer(object): def get_node_chunk_dim(self, node_from, node_from_dim, node_to): node_from_source = self._find_source_trace_from_node(node_from) dim_source = node_from_source[node_from_dim] - node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list) + node_to_idx = _find_idx_by_name(node_to.name, self.node_list) for k, v in dim_source.items(): if k == node_to_idx: return v return None def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) + input_node_idx = _find_idx_by_name(input_node.name, self.node_list) node_trace_source = self._find_source_trace_from_node(node) for node_dim in range(len(_get_node_shape(node))): if ( @@ -711,11 +711,11 @@ class IndexTracer(object): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k]) + inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) if inherit_dim: input_dim_after_node[k] = inherit_dim - for node in self.nodes_list[ + for node in self.node_list[ chunk_infos["region"][0] : chunk_infos["region"][1] + 1 ]: if _is_non_compute_node_except_placeholder(node): @@ -746,124 +746,11 @@ class IndexTracer(object): else: return True - -class FlowTracer(object): - def __init__(self, gm) -> None: - self.gm = gm - self.node_list = list(gm.graph.nodes) - self.flow_trace = {} - - def _add_trace(self, name): - self.flow_trace[name] = [] - - def _add_node(self, trace_name, node): - self.flow_trace[trace_name].append( - {"node": node, "inside_depend": [], "outside_depend": []} - ) - - def _add_inside_depend(self, flow_name, node, inside_depend_node): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["inside_depend"].append(inside_depend_node) - return - raise RuntimeError("node not found") - - def _add_outside_depend( - self, flow_name, node, outside_depend_node, outside_depend_trace - ): - for i in self.flow_trace[flow_name]: - if i["node"] == node: - i["outside_depend"].append({outside_depend_trace: outside_depend_node}) - return - raise RuntimeError("node not found") - - def _init_trace(self): - for i in self.node_list: - if i.op == "placeholder": - self._add_trace(i.name) - self._add_node(i.name, i) - - def _find_flow_for_node(self, node): - if type(self.node_list[0]) != type(node): - return None - if _is_non_compute_node_except_placeholder(node): - return None - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name - if any(i in node.name for i in ["ones_like"]): - self._add_trace(node.name) - self._add_node(node.name, node) - return node.name - raise RuntimeError("node not found") - - def _find_first_valid_flow(self, flow): - for i in flow: - if i is not None: - return i - raise RuntimeError("invalid flow") - - def find_node_flow(self, node): - for name, trace in self.flow_trace.items(): - for i in trace: - if node == i["node"]: - return name, i - raise RuntimeError("invalid node") - - def _get_flow_mix_node(self, node): - if _is_non_compute_node(node): - return None - _, node_trace = self.find_node_flow(node) - if len(node_trace["outside_depend"]) == 0: - return None - elif len(node_trace["outside_depend"]) > 1: - raise NotImplementedError - vars = list(node_trace["outside_depend"][0].values())[0] - return vars - - def _get_same_flow_node(self, node_list, node): - name, _ = self.find_node_flow(node) - result = [] - for i in self.flow_trace[name]: - if i["node"] in node_list: - result.append(i["node"]) - return result - - def trace_flow(self): - # init trace - self._init_trace() - - for node in self.node_list: - # skip if non compute node - if all( - type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg) - for arg in node.args - ) or _is_non_compute_node(node): - continue - - node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] - - node_domin_flow = self._find_first_valid_flow(node_input_flows) - self._add_node(node_domin_flow, node) - for node_input_flow, arg in zip(node_input_flows, node.args): - if node_input_flow is None: - continue - elif node_input_flow == node_domin_flow: - self._add_inside_depend(node_domin_flow, node, arg) - else: - self._add_outside_depend( - node_domin_flow, node, arg, node_input_flow - ) - return self.flow_trace - def _assgin_single_node_flow( self, arg_node, start_idx, end_idx, - inputs, - index_tracer, cur_node_dim, cur_node_compute, cur_node_source, @@ -871,7 +758,7 @@ class FlowTracer(object): all_node_info, next_node_list, ): - arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list) + arg_idx = _find_idx_by_name(arg_node.name, self.node_list) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): return True @@ -911,7 +798,7 @@ class FlowTracer(object): return True def flow_search( - self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer + self, start_idx, start_dim, end_idx, end_dim ): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] @@ -920,7 +807,7 @@ class FlowTracer(object): if len(outputs) > 1: return None - cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node + cur_node_list = [self.node_list[end_idx]] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -930,12 +817,12 @@ class FlowTracer(object): # get cur node info cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list) + cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list) if cur_node_chunk_dim: - cur_node_compute = index_tracer._find_compute_trace_from_node( + cur_node_compute = self._find_compute_trace_from_node( cur_node ) - cur_node_source = index_tracer._find_source_trace_from_node( + cur_node_source = self._find_source_trace_from_node( cur_node ) else: @@ -953,8 +840,6 @@ class FlowTracer(object): arg, start_idx, end_idx, - inputs, - index_tracer, cur_node_chunk_dim, cur_node_compute, cur_node_source, @@ -970,7 +855,7 @@ class FlowTracer(object): for arg in arg_list: if not ( start_idx - <= _find_idx_by_name(arg.name, index_tracer.nodes_list) + <= _find_idx_by_name(arg.name, self.node_list) < end_idx ): continue @@ -1029,7 +914,7 @@ class FlowTracer(object): if node_info["chunk_dim"] is None: maybe_prepose_nodes.append(node) maybe_prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), + key=lambda x: _find_idx_by_name(x.name, self.node_list), reverse=True, ) # from last node to first node prepose_nodes = [] @@ -1081,7 +966,7 @@ class FlowTracer(object): maybe_prepose_nodes.remove(n) # sort by index prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list) + key=lambda x: _find_idx_by_name(x.name, self.node_list) ) chunk_info["args"]["prepose_nodes"] = prepose_nodes @@ -1226,9 +1111,9 @@ class MemoryEstimator(object): for k, v in input_node_dim.items(): # TODO: inherit dim should be list too, int now inherit_dim = self.index_tracer._find_inherit_dim( - input_node, v, self.index_tracer.nodes_list[k] + input_node, v, self.index_tracer.node_list[k] ) - if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): + if k == _find_idx_by_name(node.name, self.index_tracer.node_list): chunk_ratio = float(chunk_size) / node_shape[inherit_dim] return chunk_ratio for dim, source in enumerate(node_source): @@ -1412,8 +1297,6 @@ class ChunkRegionSearch(object): self.node_list = list(gm.graph.nodes) self.index_tracer = IndexTracer(gm) self.index_tracer.trace_index() - self.flow_tracer = FlowTracer(gm) - self.flow_tracer.trace_flow() self.memory_estimator = MemoryEstimator(self.index_tracer) def _find_peak_node(self, mem_peak): @@ -1517,8 +1400,8 @@ class ChunkRegionSearch(object): ): continue # flow search - chunk_info = self.flow_tracer.flow_search( - start_idx, start_dim, end_idx, end_dim, self.index_tracer + chunk_info = self.index_tracer.flow_search( + start_idx, start_dim, end_idx, end_dim ) if chunk_info is None: continue