From 51ef8384c153f46dcbb74c26eec523ad7cd0d51c Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 17:25:36 +0800 Subject: [PATCH] finish node reorder --- chunk_codegen.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 4b3b04d93..9623a9d9b 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1238,7 +1238,7 @@ class MemoryEstimator(object): def estimate_chunk_inference_mem( self, - gm: torch.fx.GraphModule, + node_list, chunk_infos=None, ): act_memory = 0.0 @@ -1247,7 +1247,6 @@ class MemoryEstimator(object): active_node_list = [] active_node_list_log = [] not_contiguous_list = [] - node_list = list(gm.graph.nodes) user_to_last_uses = self._get_last_usr(node_list) user_to_last_uses_no_free_var = self._get_last_usr(node_list) _delete_free_var_from_last_use(user_to_last_uses_no_free_var) @@ -1281,7 +1280,6 @@ class MemoryEstimator(object): ) / (1024**2) # determine chunk ratio for current node - # TODO: adapt to prepose node memory if chunk_within: chunk_ratio = self._get_chunk_ratio( node, @@ -1371,10 +1369,7 @@ class MemoryEstimator(object): class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm - self.node_list = list(gm.graph.nodes) - self.index_tracer = IndexTracer( - self.node_list - ) # node list shared in index tracer + self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) @@ -1385,7 +1380,7 @@ class ChunkRegionSearch(object): def _get_free_var(self): free_var_idx = [] - for idx, n in enumerate(self.node_list): + for idx, n in enumerate(self.index_tracer.node_list): if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx @@ -1455,13 +1450,13 @@ class ChunkRegionSearch(object): def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] - end_node = self.node_list[end_idx] + end_node = self.index_tracer.node_list[end_idx] chunk_infos = [] - for end_dim, end_trace_idx in enumerate(end_trace["idx"]): + for end_dim, _ in enumerate(end_trace["idx"]): if len(start_traces) > 1: continue for start_node, start_trace in start_traces.items(): - for start_dim, start_trace_idx in enumerate(start_trace["idx"]): + for start_dim, _ in enumerate(start_trace["idx"]): # dim size cannot be 1 if ( _get_node_shape(end_node)[end_dim] == 1 @@ -1494,7 +1489,7 @@ class ChunkRegionSearch(object): possible_chunk_region = [] output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.node_list): + for _, n in enumerate(self.index_tracer.node_list): cur_trace = {} for arg in n.args: if type(arg) == type(n) and not _is_non_compute_node_except_placeholder( @@ -1507,8 +1502,8 @@ class ChunkRegionSearch(object): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if _is_non_compute_node( - self.node_list[start_idx] - ) or _is_non_compute_node(self.node_list[end_idx]): + self.index_tracer.node_list[start_idx] + ) or _is_non_compute_node(self.index_tracer.node_list[end_idx]): continue # select free dim @@ -1577,7 +1572,9 @@ class ChunkRegionSearch(object): init_mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm) + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list + ) mem_peak = init_mem_peak while True: @@ -1590,7 +1587,9 @@ class ChunkRegionSearch(object): mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos) + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos + ) if self._stop_search(init_mem_peak, mem_peak): break return chunk_infos