"""Node reordering for autochunk.

``ReorderGraph`` lives in ``colossalai/autochunk/reorder_graph.py``; the class
was moved here out of ``colossalai/autochunk/trace_index.py``.
"""
from .trace_index import TraceIndex
from .utils import find_idx_by_name


class ReorderGraph(object):
    """Move a chunk's "prepose" nodes to the front of the chunk region.

    The tracer's ``node_list`` and ``idx_trace_list`` are permuted together, and
    every node index stored inside the traces and inside ``chunk_info`` is
    rewritten so that it keeps pointing at the same node after the move.

    ``chunk_info`` is a dict; the keys read here are ``"region"`` (indexed with
    ``[0]`` and ``[1]`` as an inclusive start/end pair), ``"inputs_dim"`` (a list
    of dicts keyed by node index) and ``"args"["prepose_nodes"]`` (nodes exposing
    a ``name`` attribute).
    """

    def __init__(self, index_tracer: TraceIndex) -> None:
        self.index_tracer = index_tracer
        # Cumulative map: original node index -> index after every reorder_all() so far.
        # NOTE(review): sized from idx_trace_list while per-chunk maps are sized from
        # node_list; this presumes both lists have the same length - confirm.
        self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}

    @staticmethod
    def _apply_reorder_map(items, reorder_map):
        """Return a new list, as long as ``items``, with ``items[old]`` at ``reorder_map[old]``."""
        reordered = [None] * len(items)
        for old_idx, new_idx in reorder_map.items():
            reordered[new_idx] = items[old_idx]
        return reordered

    def _get_reorder_map(self, chunk_info):
        """Build the ``{old_index: new_index}`` map for one chunk.

        Prepose nodes are packed at the start of the region in the order they are
        listed; every other node of the region moves back by the number of
        prepose nodes that originally sat after it. Nodes outside the region
        keep their index.
        """
        node_list = self.index_tracer.node_list
        reorder_map = {i: i for i in range(len(node_list))}

        chunk_region_start = chunk_info["region"][0]
        chunk_region_end = chunk_info["region"][1]
        chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
        chunk_prepose_nodes_idx = [find_idx_by_name(node.name, node_list) for node in chunk_prepose_nodes]

        # put prepose nodes ahead
        for offset, n_idx in enumerate(chunk_prepose_nodes_idx):
            reorder_map[n_idx] = chunk_region_start + offset

        # put other nodes after prepose nodes
        for node in node_list[chunk_region_start : chunk_region_end + 1]:
            if node in chunk_prepose_nodes:
                continue
            # NOTE(review): the name lookup is kept (rather than counting from
            # chunk_region_start) because find_idx_by_name is defined elsewhere
            # and its handling of duplicate names is not visible here.
            n_idx = find_idx_by_name(node.name, node_list)
            # Each prepose node that used to come later now sits in front of this one.
            shift = sum(n_idx < prepose_idx for prepose_idx in chunk_prepose_nodes_idx)
            reorder_map[n_idx] = n_idx + shift

        return reorder_map

    def _reorder_chunk_info(self, chunk_info, reorder_map):
        """Rewrite ``chunk_info`` in place for the new node order and return it.

        The region start is advanced past the prepose nodes (they now occupy the
        first slots of the old region) and the node-index keys of every
        ``inputs_dim`` dict are translated through ``reorder_map``.
        """
        prepose_count = len(chunk_info["args"]["prepose_nodes"])
        chunk_info["region"] = (
            chunk_info["region"][0] + prepose_count,
            chunk_info["region"][1],
        )
        chunk_info["inputs_dim"] = [
            {reorder_map[k]: v for k, v in input_dim.items()} for input_dim in chunk_info["inputs_dim"]
        ]
        return chunk_info

    def _update_all_reorder_map(self, reorder_map):
        """Compose ``reorder_map`` onto the cumulative ``all_reorder_map``."""
        # Only values of existing keys are replaced, so the dict keeps its size
        # and identity while being iterated.
        for origin_idx, map_idx in self.all_reorder_map.items():
            self.all_reorder_map[origin_idx] = reorder_map[map_idx]

    def _reorder_self_node_list(self, reorder_map):
        """Permute the tracer's ``node_list`` according to ``reorder_map``."""
        self.index_tracer.node_list = self._apply_reorder_map(self.index_tracer.node_list, reorder_map)

    def _reorder_idx_trace(self, reorder_map):
        """Permute ``idx_trace_list`` and remap the node indices stored in each trace."""
        # reorder list
        self.index_tracer.idx_trace_list = self._apply_reorder_map(self.index_tracer.idx_trace_list, reorder_map)
        # update compute: each per-dim list is rewritten in place (same list object)
        for idx_trace in self.index_tracer.idx_trace_list:
            for dim_compute in idx_trace["compute"]:
                dim_compute[:] = [reorder_map[i] for i in dim_compute]
        # update source: each per-dim dict is replaced by one with remapped keys
        for idx_trace in self.index_tracer.idx_trace_list:
            source = idx_trace["source"]
            for dim_idx, dim_source in enumerate(source):
                source[dim_idx] = {reorder_map[k]: v for k, v in dim_source.items()}

    def reorder_all(self, chunk_info):
        """Apply the chunk's reordering to the tracer state and to ``chunk_info``.

        Returns ``chunk_info`` unchanged when it is ``None`` or has no prepose
        nodes. Otherwise the tracer's ``node_list`` / ``idx_trace_list`` and the
        cumulative ``all_reorder_map`` are updated, and ``chunk_info`` is
        modified in place and returned.
        """
        if chunk_info is None:
            return chunk_info
        if not chunk_info["args"]["prepose_nodes"]:
            return chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        self._update_all_reorder_map(reorder_map)
        # The traces are remapped before node_list is permuted (original order kept).
        self._reorder_idx_trace(reorder_map)
        self._reorder_self_node_list(reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return chunk_info

    def reorder_node_list(self, node_list):
        """Return ``node_list`` permuted by the cumulative ``all_reorder_map``."""
        return self._apply_reorder_map(node_list, self.all_reorder_map)

    def tmp_reorder(self, node_list, chunk_info):
        """Return ``(reordered_node_list, chunk_info)`` without touching tracer state.

        ``node_list`` itself is not modified; a new list is returned. Note that
        ``chunk_info`` IS modified in place (its ``"region"`` and
        ``"inputs_dim"`` entries are replaced), so a caller that needs the
        original values must pass a copy.

        ``chunk_info`` of ``None`` or one without prepose nodes is returned as is
        together with the untouched ``node_list``; the ``None`` case mirrors the
        guard in ``reorder_all``.
        """
        if chunk_info is None or not chunk_info["args"]["prepose_nodes"]:
            return node_list, chunk_info
        reorder_map = self._get_reorder_map(chunk_info)

        # new tmp node list
        new_node_list = self._apply_reorder_map(node_list, reorder_map)

        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return new_node_list, chunk_info