From 7330d907459a220ebedaeafbbcc7c3cff3c8b1c4 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Sun, 4 Dec 2022 17:05:28 +0800
Subject: [PATCH] add possible region search

---
 chunk_codegen.py | 116 ++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 109 insertions(+), 7 deletions(-)

diff --git a/chunk_codegen.py b/chunk_codegen.py
index 77aca8deb..ba83f7fec 100644
--- chunk_codegen.py
+++ chunk_codegen.py
@@ -356,7 +356,17 @@ class NodeIndexTracer(object):
                      "idx_to": [new_trace[i] for i in dim_to],
                      "dim_to": dim_to}
         self.idx_view_list.append(view_dict)
-    
+
+    def _merge_equal_idx(self):
+        idx_equal = copy.deepcopy(self.idx_trace_equal)
+        idx_equal.reverse()
+        for idx in idx_equal:
+            merge_to = min(idx)
+            merge_from = max(idx)
+            for trace in self.idx_trace_list:
+                if merge_from in trace['idx']:
+                    trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']]
+
     def trace_node_idx(self):
         for idx, node in enumerate(self.nodes_list):
             if node.op == 'placeholder':
@@ -396,6 +406,7 @@
                 continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
+        self._merge_equal_idx()
 
 
 class MemoryEstimator(object):
@@ -433,6 +444,8 @@ class MemoryEstimator(object):
         for i in range(len(out_node)):
             if out_node[i][0] > 0:
                 delete_node.append(out_node[i][1][0])
+            elif nodes_to_delete[i].op == 'placeholder':
+                delete_node.append(nodes_to_delete[i].name)
         return delete_size, delete_node
 
     def _get_delete_node_size(self, user, user_to_last_uses):
@@ -516,8 +529,9 @@
         active_node_list_log = []
         not_contiguous_list = []
         node_list = list(gm.graph.nodes)
-        user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
-        _delete_free_var_from_last_use(user_to_last_uses)
+        user_to_last_uses = self._get_last_usr(node_list)
+        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
+        _delete_free_var_from_last_use(user_to_last_uses_no_free_var)
 
         use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes])
         chunk_within = False
@@ -535,6 +549,7 @@
             if node.op == 'placeholder':
                 act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
                 act_memory_peak_log.append(act_memory)
+                active_node_list.append(node.name)
             # skip output
             elif node.op == 'output':
                 continue
@@ -549,10 +564,10 @@
                 act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
                 if chunk_within:
                     act_memory -= self._get_chunk_delete_node_size(
-                        node, user_to_last_uses, chunk_ratio, node_list,
+                        node, user_to_last_uses_no_free_var, chunk_ratio, node_list,
                         start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2)
                 else:
-                    act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
+                    act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var) / (1024 ** 2)
 
                 # log active node
                 self._add_active_node(node, active_node_list)
@@ -572,8 +587,92 @@
         self._print_mem_log(act_memory_peak_log, node_list, "peak")
         self._print_mem_log(act_memory_after_node_log, node_list, "after")
 
-        param_memory = parameter_size(gm)
-        return act_memory + param_memory, param_memory
+        # param_memory = parameter_size(gm)
+        # all_memory = act_memory + param_memory
+        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
+
+
+class ChunkRegionSearch(object):
+    def __init__(self, gm) -> None:
+        self.gm = gm
+        self.node_list = list(gm.graph.nodes)
+        self.memory_estimator = MemoryEstimator()
+        self.index_tracer = NodeIndexTracer(gm)
+        self.index_tracer.trace_node_idx()
+
+    def _find_peak_node(self, mem_peak):
+        max_value = max(mem_peak)
+        max_idx = [mem_peak.index(max_value)]
+        return max_idx
+
+    def _get_free_var(self):
+        free_var_idx = []
+        for idx, n in enumerate(self.node_list):
+            if n.op == 'placeholder':
+                free_var_idx.append(idx)
+        return free_var_idx
+
+    def _get_min_free_var(self, active_node_list, free_vars):
+        min_len = 999
+        for idx, n in enumerate(active_node_list):
+            if idx in free_vars:
+                continue
+            if len(n) < min_len:
+                min_len = len(n)
+        return min_len
+
+    def _search_max_chunk_region(self, active_node, peak_node):
+        free_vars = self._get_free_var()
+        min_var = self._get_min_free_var(active_node, free_vars)
+
+        # from peak_node to free_var
+        chunk_region_start = None
+        for i in range(peak_node, -1, -1):
+            if len(active_node[i]) == min_var:
+                chunk_region_start = i + 1
+                break
+            if i in free_vars or i == 0:
+                raise RuntimeError()
+        # from peak_node to len-2
+        chunk_region_end = None
+        for i in range(peak_node, len(active_node) - 1):
+            if len(active_node[i]) == min_var:
+                chunk_region_end = i - 1
+                break
+            if i in free_vars or i == 0:
+                raise RuntimeError()
+        return chunk_region_start, chunk_region_end
+
+    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
+        possible_chunk_region = []
+        for before_idx in range(max_chunk_region[0], peak_node):
+            for after_idx in range(peak_node, max_chunk_region[1]):
+                # skip non compute nodes
+                if any(op in ['placeholder', 'get_attr', 'output'] for op in
+                       [self.node_list[before_idx].op, self.node_list[after_idx].op]):
+                    continue
+                if any(any(i in name for i in ['getitem', 'getattr']) for name in
+                       [self.node_list[before_idx].name, self.node_list[after_idx].name]):
+                    continue
+
+                # select free dim
+                before_trace = self.index_tracer.idx_trace_list[before_idx]
+                after_trace = self.index_tracer.idx_trace_list[after_idx]
+                free_dim = []
+                for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
+                    if (before_trace['idx'][i] == after_trace['idx'][i] and
+                            before_trace['idx'][i] not in before_trace['compute'] and
+                            after_trace['idx'][i] not in after_trace['compute']):
+                        free_dim.append(i)
+                possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
+        return possible_chunk_region
+
+    def search_region(self):
+        mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
+        peak_nodes = self._find_peak_node(mem_peak)
+        for idx, peak_node in enumerate(peak_nodes):
+            max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
+            possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
 
 
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
@@ -696,6 +795,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
 
     node_index_tracer = NodeIndexTracer(meta_graph)
     node_index_tracer.trace_node_idx()
+
+    chunk_region_search = ChunkRegionSearch(meta_graph)
+    chunk_region_search.search_region()
 
     # find the input and output var names for each offload region
     for idx, (start, end) in enumerate(chunk_regions):
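
Note (not part of the patch): as a quick illustration of what `_find_peak_node` and `_search_max_chunk_region` compute, here is a minimal standalone sketch of the same idea: pick the step with the highest estimated activation memory, then grow a window around it until the number of live tensors falls back to a baseline. The input lists, the `baseline` parameter, and the simplified expansion rule are assumptions made for this sketch only; the patched code additionally excludes free variables and walks over node-name lists rather than plain counts.

```python
# Illustrative sketch only; mirrors the region-search idea, not the patched implementation.

def find_peak_step(mem_peak):
    # Index of the step with the highest estimated activation memory.
    return mem_peak.index(max(mem_peak))


def max_chunk_region(active_counts, peak_step, baseline):
    # Expand left and right from the peak until the live-tensor count drops
    # back to the baseline; the span between those points is the widest
    # region a chunk could cover.
    start = peak_step
    while start > 0 and active_counts[start - 1] > baseline:
        start -= 1
    end = peak_step
    while end < len(active_counts) - 1 and active_counts[end + 1] > baseline:
        end += 1
    return start, end


if __name__ == "__main__":
    # Hypothetical per-step estimates for a 10-node graph.
    mem_peak = [1, 2, 3, 5, 9, 12, 8, 4, 2, 1]       # estimated MB after each node
    active_counts = [1, 2, 3, 4, 5, 6, 5, 3, 2, 1]   # live tensors after each node
    peak = find_peak_step(mem_peak)                   # -> 5
    print(max_chunk_region(active_counts, peak, baseline=2))  # -> (2, 7)
```

Within that window, `_search_possible_chunk_regions` then pairs every compute node before the peak with every compute node after it and keeps the dimensions whose traced indices match and are not involved in a computation, which is what the `free_dim` list records.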