From 6685a9d022a912ab3d0a57486b045b92b3f681ce Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 15:53:24 +0800 Subject: [PATCH] separate non chunk input --- colossalai/autochunk/index_tracer.py | 35 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 206d2edbd..202044763 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -839,7 +839,7 @@ class IndexTracer(object): inputs.remove(i) return inputs, inputs_dim - def _set_prepose_nodes(self, all_node_info, start_idx, end_idx): + def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): # get all possible prepose nodes maybe_prepose_nodes = [] for node, node_info in all_node_info.items(): @@ -902,7 +902,19 @@ class IndexTracer(object): prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) return prepose_nodes - + + def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): + # we need to log input nodes to avoid deleting them in the loop + chunk_node_list = self.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in chunk_info["args"]["prepose_nodes"]: + chunk_node_list.remove(n) + non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + return chunk_info + def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] @@ -917,7 +929,9 @@ class IndexTracer(object): return None # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + inputs, inputs_dim = self._get_input_nodes_dim( + inputs, start_idx, end_idx, all_node_info + ) if inputs is None: return None @@ -933,17 +947,12 
@@ class IndexTracer(object): } # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx) + chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( + all_node_info, start_idx, end_idx + ) - # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.node_list[start_idx : end_idx + 1] - # also need to get some prepose node's arg out of non_chunk_inputs - for n in chunk_info["args"]["prepose_nodes"]: - chunk_node_list.remove(n) - non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) + # find non chunk inputs + chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) # reassgin reshape size, some size may have changed due to chunk chunk_info = self._reassgin_reshape_size(chunk_info)