diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 221217e2d..206d2edbd 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -839,36 +839,7 @@ class IndexTracer(object): inputs.remove(i) return inputs, inputs_dim - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - # get every node's chunk dim and fix dim - all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) - if all_node_info is None: - return None - - # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) - if inputs is None: - return None - - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "node_chunk_dim": all_node_info, - "args": {}, - } - - # move useless nodes ahead of loop + def _set_prepose_nodes(self, all_node_info, start_idx, end_idx): # get all possible prepose nodes maybe_prepose_nodes = [] for node, node_info in all_node_info.items(): @@ -929,12 +900,45 @@ class IndexTracer(object): maybe_prepose_nodes.remove(n) # sort by index prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) - chunk_info["args"]["prepose_nodes"] = prepose_nodes + + return prepose_nodes + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + if inputs is None: + return None + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, + "args": {}, + } + + # move useless nodes ahead of loop + chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx) # we need to log input nodes to avoid deleteing them in the loop chunk_node_list = self.node_list[start_idx : end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs - for n in prepose_nodes: + for n in chunk_info["args"]["prepose_nodes"]: chunk_node_list.remove(n) non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) for i in non_chunk_inputs: