from .trace_indice import TraceIndice from .utils import find_idx_by_name class ReorderGraph(object): """ Reorder node list and indice trace list """ def __init__(self, trace_indice: TraceIndice) -> None: self.trace_indice = trace_indice self.all_reorder_map = { i: i for i in range(len(self.trace_indice.indice_trace_list)) } def _get_reorder_map(self, chunk_info): reorder_map = {i: i for i in range(len(self.trace_indice.node_list))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes_idx = [ find_idx_by_name(i.name, self.trace_indice.node_list) for i in chunk_prepose_nodes ] # put prepose nodes ahead for idx, n in enumerate(chunk_prepose_nodes): n_idx = chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]: if n in chunk_prepose_nodes: continue n_idx = find_idx_by_name(n.name, self.trace_indice.node_list) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos return reorder_map def _reorder_chunk_info(self, chunk_info, reorder_map): # update chunk info chunk_info["region"] = ( chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), chunk_info["region"][1], ) new_inputs_dim = [] for idx, input_dim in enumerate(chunk_info["inputs_dim"]): new_input_dim = {} for k, v in input_dim.items(): new_input_dim[reorder_map[k]] = v new_inputs_dim.append(new_input_dim) chunk_info["inputs_dim"] = new_inputs_dim return chunk_info def _update_all_reorder_map(self, reorder_map): for origin_idx, map_idx in self.all_reorder_map.items(): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): new_node_list = [None for _ in range(len(self.trace_indice.node_list))] for old_idx, new_idx in reorder_map.items(): new_node_list[new_idx] = self.trace_indice.node_list[old_idx] self.trace_indice.node_list = new_node_list def _reorder_idx_trace(self, reorder_map): # reorder list new_idx_trace_list = [ None for _ in range(len(self.trace_indice.indice_trace_list)) ] for old_idx, new_idx in reorder_map.items(): new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx] self.trace_indice.indice_trace_list = new_idx_trace_list # update compute for idx_trace in self.trace_indice.indice_trace_list: compute = idx_trace["compute"] for dim_compute in compute: for idx, i in enumerate(dim_compute): dim_compute[idx] = reorder_map[i] # update source for idx_trace in self.trace_indice.indice_trace_list: source = idx_trace["source"] for dim_idx, dim_source in enumerate(source): new_dim_source = {} for k, v in dim_source.items(): new_dim_source[reorder_map[k]] = v source[dim_idx] = new_dim_source def reorder_all(self, chunk_info): if chunk_info is None: return chunk_info if len(chunk_info["args"]["prepose_nodes"]) == 0: return chunk_info reorder_map = self._get_reorder_map(chunk_info) self._update_all_reorder_map(reorder_map) self._reorder_idx_trace(reorder_map) self._reorder_self_node_list(reorder_map) chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) return chunk_info def reorder_node_list(self, node_list): new_node_list = [None for _ in range(len(node_list))] for old_idx, new_idx in self.all_reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] return new_node_list def tmp_reorder(self, node_list, chunk_info): if len(chunk_info["args"]["prepose_nodes"]) == 0: return node_list, chunk_info reorder_map = self._get_reorder_map(chunk_info) # new tmp node list new_node_list = [None for _ in range(len(node_list))] for old_idx, new_idx in reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) return new_node_list, chunk_info