diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 9ad2649e7..a72fd775b 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -111,21 +111,6 @@ class TraceIndice(object): if j not in node_to_compute[i]: node_to_compute[i].append(j) - def _mark_indice_equal(self, node1, dim1, node2, dim2): - """ - Mark 2 indice to be equal. - - Args: - idx1 (int): indice count. - idx2 (int): indice count. - """ - # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) - # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) - # if node1_idx > node2_idx: - # self._add_source(node2, dim2, node1, dim1) - # else: - # self._add_source(node1, dim1, node2, dim2) - def _mark_computation(self, node, idx, dim): """ Mark some dims of node as computed. @@ -273,19 +258,14 @@ class TraceIndice(object): node_idx (int) """ if len(node.args) == 2: - input_node, weight = node.args - bias = None + _, weight = node.args else: - input_node, weight, bias = node.args + _, weight, _ = node.args self._assign_indice_as_input(node, node_idx) self._inherit_indice(weight, 1, node, -1) self._mark_computation(node, node_idx, [-1]) - self._mark_indice_equal(input_node, -1, weight, 0) - - if bias: - self._mark_indice_equal(input_node, -1, bias, 0) def _assign_matmul_indice(self, node, node_idx): """ @@ -306,7 +286,6 @@ class TraceIndice(object): self._mark_computation_from_node(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) - self._mark_indice_equal(matmul_left, -1, matmul_right, -2) def _assign_layernorm_indice(self, node, idx): """ @@ -338,12 +317,6 @@ class TraceIndice(object): nodes_in.append(node_in) self._mark_computation_from_node(node_in, node) assert len(nodes_in) <= 2 - if len(nodes_in) == 2: - node_in0_shape = get_node_shape(nodes_in[0]) - node_in1_shape = get_node_shape(nodes_in[1]) - for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): - if node_in0_shape[i] == node_in1_shape[i]: - self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i) def _assgin_no_change_indice(self, node, idx): self._assign_indice_as_input(node, idx)