From ded1005667402ee9458afa53852ce2018b1ccb10 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 21 Dec 2022 15:03:08 +0800 Subject: [PATCH] format code --- chunk_codegen.py | 184 +++++++++++++++++++++++++++++++---------------- 1 file changed, 122 insertions(+), 62 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 3ba082ceb..eb16361c0 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -144,7 +144,9 @@ class IndexTracer(object): node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] else: if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: - node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim) + node_to_trace["source"][node_to_dim][node_from_idx].append( + node_from_dim + ) # update inputs source node_to_trace["source"][node_to_dim].update( node_from_trace["source"][node_from_dim] @@ -745,7 +747,6 @@ class IndexTracer(object): return True - class FlowTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -856,7 +857,9 @@ class FlowTracer(object): ) return self.flow_trace - def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + def _detect_flow( + self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer + ): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) @@ -945,8 +948,10 @@ class FlowTracer(object): for i in remove_inputs: if i in chunk_info["inputs"]: chunk_info["inputs"].remove(i) - - duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True) + + duplicate_result, duplicate_dim = index_tracer.check_index_duplicate( + chunk_info, return_dim=True + ) # we need to log input nodes to avoid deleteing them in the loop non_chunk_inputs = _find_chunk_all_input_nodes( @@ -958,15 +963,25 @@ class FlowTracer(object): return flow_block, chunk_info - def _assgin_single_node_flow(self, arg_node, start_idx, end_idx, - inputs, index_tracer, cur_node_dim, - cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, - next_node_list): + def _assgin_single_node_flow( + self, + arg_node, + start_idx, + end_idx, + inputs, + index_tracer, + cur_node_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ): arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): return True - + # find arg dim if cur_node_dim is not None: # dim is computed @@ -978,7 +993,7 @@ class FlowTracer(object): arg_dim = cur_node_source[cur_node_dim][arg_idx][0] else: arg_dim = None - + # get fix dim arg_fix_dim = [] if cur_node_dim is not None: @@ -986,44 +1001,52 @@ class FlowTracer(object): fix_dim_source = cur_node_source[i] if arg_idx in fix_dim_source: arg_fix_dim.append(fix_dim_source[arg_idx][0]) - + # if already in node_info, arg dim must be same if arg_node in all_node_info: if all_node_info[arg_node] != arg_dim: return False - all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim)) + all_node_info[arg_node]["fix_dim"] = list( + set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) + ) # else add it to list else: - all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim} - + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + next_node_list.append(arg_node) return True - - def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + + def flow_search( + self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer + ): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) # only single ouput if len(outputs) > 1: return None - + cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node - all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}} - + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + while len(cur_node_list) > 0: next_node_list = [] for cur_node in cur_node_list: # get cur node info - cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim'] - cur_node_fix_dim = all_node_info[cur_node]['fix_dim'] + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list) if cur_node_chunk_dim: - cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node) - cur_node_source = index_tracer._find_source_trace_from_node(cur_node) + cur_node_compute = index_tracer._find_compute_trace_from_node( + cur_node + ) + cur_node_source = index_tracer._find_source_trace_from_node( + cur_node + ) else: cur_node_compute = cur_node_source = None - + # get all valid args arg_list = [] for arg in cur_node.args: @@ -1032,20 +1055,33 @@ class FlowTracer(object): if _is_non_compute_node(arg): continue arg_list.append(arg) - flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx, - inputs, index_tracer, cur_node_chunk_dim, - cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, - next_node_list) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + inputs, + index_tracer, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) if flow_flag == False: return None - + if len(arg_list) == 2: if any(i in cur_node.name for i in ["add", "mul"]): for arg in arg_list: - if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx): + if not ( + start_idx + <= _find_idx_by_name(arg.name, index_tracer.nodes_list) + < end_idx + ): continue - arg_chunk_dim = all_node_info[arg]['chunk_dim'] - arg_fix_dim = all_node_info[arg]['fix_dim'] + arg_chunk_dim = all_node_info[arg]["chunk_dim"] + arg_fix_dim = all_node_info[arg]["fix_dim"] arg_shape = _get_node_shape(arg) # add all dim as fix dim except chunk dim for i, shape in enumerate(arg_shape): @@ -1061,7 +1097,7 @@ class FlowTracer(object): else: raise NotImplementedError() cur_node_list = next_node_list - + inputs_dim = [] remove_inputs = [] for input_node in inputs: @@ -1071,7 +1107,7 @@ class FlowTracer(object): continue user_idx = _find_idx_by_name(user.name, self.node_list) if start_idx <= user_idx <= end_idx: - chunk_dim = all_node_info[user]['chunk_dim'] + chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: input_dict[user_idx] = chunk_dim if len(input_dict) == 0: @@ -1081,7 +1117,7 @@ class FlowTracer(object): for i in remove_inputs: if i in inputs: inputs.remove(i) - + chunk_info = { "region": (start_idx, end_idx), "inputs": inputs, @@ -1091,7 +1127,7 @@ class FlowTracer(object): "outputs_dim": end_dim, "args": {}, } - + # we need to log input nodes to avoid deleteing them in the loop non_chunk_inputs = _find_chunk_all_input_nodes( self.node_list[start_idx : end_idx + 1] @@ -1129,7 +1165,7 @@ class MemoryEstimator(object): def _add_active_node(self, n, active_list): new_active = self._get_output_node(n)[1] - if n.op == 'placeholder': + if n.op == "placeholder": new_active.append(n.name) for i in new_active: if i not in active_list: @@ -1168,12 +1204,16 @@ class MemoryEstimator(object): for i in delete_node: if i in active_list: active_list.remove(i) - - def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx): + + def _get_chunk_inputs_size( + self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx + ): nodes_to_delete = [] for chunk_input in chunk_inputs + chunk_inputs_non_chunk: chunk_input_users = chunk_input.users.keys() - chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users] + chunk_input_users_idx = [ + _find_idx_by_name(i.name, node_list) for i in chunk_input_users + ] if all(i <= chunk_end_idx for i in chunk_input_users_idx): if chunk_input not in nodes_to_delete: nodes_to_delete.append(chunk_input) @@ -1226,7 +1266,9 @@ class MemoryEstimator(object): for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): for k, v in input_node_dim.items(): # TODO: inherit dim should be list too, int now - inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k]) + inherit_dim = self.index_tracer._find_inherit_dim( + input_node, v, self.index_tracer.nodes_list[k] + ) if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): chunk_ratio = float(chunk_size) / node_shape[inherit_dim] return chunk_ratio @@ -1234,7 +1276,7 @@ class MemoryEstimator(object): if k in source and inherit_dim in source[k]: chunk_ratio = float(chunk_size) / node_shape[dim] return chunk_ratio - return 1. + return 1.0 def _get_chunk_delete_node_size( self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names @@ -1295,7 +1337,7 @@ class MemoryEstimator(object): chunk_ratio = 1 # use it to estimate chunk mem chunk_size = 1 chunk_inputs_names = [] - + if use_chunk: chunk_regions = [i["region"] for i in chunk_infos] chunk_starts = [i[0] for i in chunk_regions] @@ -1313,12 +1355,17 @@ class MemoryEstimator(object): if use_chunk and idx in chunk_starts: chunk_within = True chunk_region_idx = chunk_starts.index(idx) - act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2) + act_memory += self._get_output_node_size( + chunk_outputs[chunk_region_idx] + ) / (1024**2) # determine chunk ratio for current node if chunk_within: chunk_ratio = self._get_chunk_ratio( - node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size + node, + chunk_inputs[chunk_region_idx], + chunk_inputs_dim[chunk_region_idx], + chunk_size, ) # if node is placeholder, just add the size of the node @@ -1353,18 +1400,18 @@ class MemoryEstimator(object): / (1024**2) ) # delete unused vars not in chunk_input_list - # we can't delete input nodes until chunk ends + # we can't delete input nodes until chunk ends if chunk_within: act_memory -= self._get_chunk_delete_node_size( node, user_to_last_uses_no_free_var, chunk_ratio, - chunk_inputs_names + chunk_inputs_names, ) / (1024**2) else: - act_memory -= (self._get_delete_node_size( + act_memory -= self._get_delete_node_size( node, user_to_last_uses_no_free_var, chunk_inputs_names - ) / (1024**2)) + ) / (1024**2) # log active node, only effective without chunk self._add_active_node(node, active_node_list) @@ -1376,11 +1423,11 @@ class MemoryEstimator(object): self._get_output_node_size(node) * chunk_ratio / (1024**2) ) act_memory -= self._get_chunk_inputs_size( - chunk_inputs[chunk_region_idx], - chunk_inputs_non_chunk[chunk_region_idx], + chunk_inputs[chunk_region_idx], + chunk_inputs_non_chunk[chunk_region_idx], node_list, - chunk_regions[chunk_region_idx][1] - ) / (1024**2) + chunk_regions[chunk_region_idx][1], + ) / (1024**2) chunk_within = False chunk_ratio = 1 chunk_region_idx = None @@ -1436,7 +1483,7 @@ class ChunkRegionSearch(object): active_node_num = [len(i) for i in active_node] min_active_node_num = min(active_node_num[free_var_num:]) threshold = max(free_var_num, min_active_node_num) - + # from peak_node to free_var inside_flag = False chunk_region_start = free_var_num @@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object): continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): - if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2: + if ( + start_idx == 199 + and end_idx == 229 + and start_dim == 2 + and end_dim == 2 + ): print(1) self.flow_tracer.flow_search( start_idx, start_dim, end_idx, end_dim, self.index_tracer @@ -1576,7 +1628,7 @@ class ChunkRegionSearch(object): max_region_range = 0 best_region = None return best_region - + def _is_legal_region(self, cur_chunk_info, chunk_infos): (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] if cur_chunk_info in chunk_infos: @@ -1585,11 +1637,13 @@ class ChunkRegionSearch(object): return False for i in chunk_infos: region = i["region"] - if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0])): + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): return False return True - + def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region( @@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions) + best_chunk_region = self._search_best_chunk_region( + possible_chunk_regions, chunk_regions + ) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): @@ -1667,7 +1723,11 @@ def _gen_loop_end( chunk_slice = _gen_chunk_slice_dim( chunk_outputs_dim, "chunk_idx", chunk_output_shape ) - context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name) + context = " chunk_result%s = %s; %s = None\n" % ( + chunk_slice, + chunk_outputs_name, + chunk_outputs_name, + ) context += ( chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" )