diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 67f764a31..eee357073 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -158,11 +158,11 @@ class SearchChunk(object): end_trace = output_trace[end_idx] end_node = self.trace_indice.node_list[end_idx] chunk_infos = [] - for end_dim, _ in enumerate(end_trace["idx"]): + for end_dim, _ in enumerate(end_trace["indice"]): if len(start_traces) > 1: continue for start_node, start_trace in start_traces.items(): - for start_dim, _ in enumerate(start_trace["idx"]): + for start_dim, _ in enumerate(start_trace["indice"]): # dim size cannot be 1 if ( get_node_shape(end_node)[end_dim] == 1 diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 669bfb30a..791e5a36e 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -19,12 +19,12 @@ class TraceIndice(object): for n in self.node_list: if get_node_shape(n) != None: cur_trace = { - "idx": [None for _ in range(len(get_node_shape(n)))], + "indice": [None for _ in range(len(get_node_shape(n)))], "compute": [[] for _ in range(len(get_node_shape(n)))], "source": [{} for _ in range(len(get_node_shape(n)))], } else: - cur_trace = {"idx": [], "compute": [], "source": []} + cur_trace = {"indice": [], "compute": [], "source": []} indice_trace_list.append(cur_trace) return indice_trace_list @@ -39,12 +39,12 @@ class TraceIndice(object): return self.indice_count def _del_dim(self, idx, dim_idx): - self.indice_trace_list[idx]["idx"].pop(dim_idx) + self.indice_trace_list[idx]["indice"].pop(dim_idx) self.indice_trace_list[idx]["compute"].pop(dim_idx) self.indice_trace_list[idx]["source"].pop(dim_idx) def _add_dim(self, node_idx, dim_idx): - self.indice_trace_list[node_idx]["idx"].insert(dim_idx, self._add_indice()) + self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice()) self.indice_trace_list[node_idx]["compute"].insert(dim_idx, []) self.indice_trace_list[node_idx]["source"].insert(dim_idx, {}) @@ -58,7 +58,7 @@ class TraceIndice(object): node_to_dim = self._transform_indice(node_to, node_to_dim) node_from_trace = self._find_trace_from_node(node_from) node_to_trace = self._find_trace_from_node(node_to) - node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] + node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] node_to_trace["compute"][node_to_dim] = copy.deepcopy( node_from_trace["compute"][node_from_dim] ) @@ -181,7 +181,7 @@ class TraceIndice(object): idx (list): idx of the node """ node_idx = find_idx_by_name(node.name, self.node_list) - return self.indice_trace_list[node_idx]["idx"] + return self.indice_trace_list[node_idx]["indice"] def _find_compute_trace_from_node(self, node): """ @@ -195,7 +195,7 @@ class TraceIndice(object): node_idx = find_idx_by_name(node.name, self.node_list) return self.indice_trace_list[node_idx]["compute"] - def _assign_index_as_input(self, node, node_idx, input_node=None): + def _assign_indice_as_input(self, node, node_idx, input_node=None): """ Assign node's trace as its input node. @@ -206,10 +206,10 @@ class TraceIndice(object): if input_node == None: input_node = node.args[0] input_node_idx = find_idx_by_name(input_node.name, self.node_list) - input_node_idx_trace = self.indice_trace_list[input_node_idx]["idx"] + input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"] new_idx_trace = copy.deepcopy(input_node_idx_trace) - self.indice_trace_list[node_idx]["idx"] = new_idx_trace + self.indice_trace_list[node_idx]["indice"] = new_idx_trace self._inherit_all_computation(input_node, node) @@ -225,7 +225,7 @@ class TraceIndice(object): new_trace = [] for _ in shape: new_trace.append(self._add_indice()) - self.indice_trace_list[node_idx]["idx"] = new_trace + self.indice_trace_list[node_idx]["indice"] = new_trace def _assign_transpose_indice(self, node, node_idx): """ @@ -240,7 +240,7 @@ class TraceIndice(object): input_node = node.args[0] tranpose_dim = node.args[1:] - self._assign_index_as_input(node, node_idx, input_node) + self._assign_indice_as_input(node, node_idx, input_node) self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) @@ -257,7 +257,7 @@ class TraceIndice(object): permute_dim = node.args[1:] input_node = node.args[0] - self._assign_index_as_input(node, node_idx, input_node) + self._assign_indice_as_input(node, node_idx, input_node) for idx, d in enumerate(permute_dim): self._inherit_indice(input_node, d, node, idx) @@ -278,7 +278,7 @@ class TraceIndice(object): else: input_node, weight, bias = node.args - self._assign_index_as_input(node, node_idx) + self._assign_indice_as_input(node, node_idx) self._inherit_indice(weight, 1, node, -1) self._mark_computation(node, node_idx, [-1]) @@ -301,7 +301,7 @@ class TraceIndice(object): matmul_left, matmul_right = node.args assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) - self._assign_index_as_input(node, node_idx, matmul_left) + self._assign_indice_as_input(node, node_idx, matmul_left) self._inherit_indice(matmul_right, -1, node, -1) self._mark_computation_from_node(matmul_right, node, [-1, -2]) @@ -318,7 +318,7 @@ class TraceIndice(object): node (node) node_idx (int) """ - self._assign_index_as_input(node, idx) + self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [-1]) def _assign_elementwise_indice(self, node, idx): @@ -331,7 +331,7 @@ class TraceIndice(object): node (node) node_idx (int) """ - self._assign_index_as_input(node, idx) + self._assign_indice_as_input(node, idx) nodes_in = [] for node_in in node.args: if type(node_in) == type(node): @@ -346,7 +346,7 @@ class TraceIndice(object): self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i) def _assgin_no_change_indice(self, node, idx): - self._assign_index_as_input(node, idx) + self._assign_indice_as_input(node, idx) for node_in in node.args: if type(node_in) == type(node): self._mark_computation_from_node(node_in, node) @@ -398,7 +398,7 @@ class TraceIndice(object): node (node) node_idx (int) """ - self._assign_index_as_input(node, idx) + self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [node.kwargs["dim"]]) def _assign_unsqueeze_indice(self, node, node_idx): @@ -411,7 +411,7 @@ class TraceIndice(object): node_idx (int) """ self._del_dim(node_idx, -1) - self._assign_index_as_input(node, node_idx) + self._assign_indice_as_input(node, node_idx) self._add_dim(node_idx, node.args[1]) def _assign_dropout_indice(self, node, node_idx): @@ -423,7 +423,7 @@ class TraceIndice(object): node (node) node_idx (int) """ - self._assign_index_as_input(node, node_idx) + self._assign_indice_as_input(node, node_idx) def _assign_ones_like_indice(self, node, node_idx): """ @@ -497,7 +497,7 @@ class TraceIndice(object): # get new index origin_trace = self._find_indice_trace_from_node(origin_node) - self._assign_index_as_input(node, node_idx, origin_node) + self._assign_indice_as_input(node, node_idx, origin_node) dim_from.reverse() for i in dim_from: self._del_dim(node_idx, i) @@ -516,7 +516,7 @@ class TraceIndice(object): view_dict = { "idx_from": [origin_trace[i] for i in dim_from], "dim_from": dim_from, - "idx_to": [self.indice_trace_list[node_idx]["idx"][i] for i in dim_to], + "idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to], "dim_to": dim_to, } self.indice_view_list[node] = view_dict @@ -528,9 +528,9 @@ class TraceIndice(object): merge_to = min(idx) merge_from = max(idx) for trace in self.indice_trace_list: - if merge_from in trace["idx"]: - trace["idx"] = [ - merge_to if i == merge_from else i for i in trace["idx"] + if merge_from in trace["indice"]: + trace["indice"] = [ + merge_to if i == merge_from else i for i in trace["indice"] ] def trace_index(self):