|
|
|
@ -111,21 +111,6 @@ class TraceIndice(object):
|
|
|
|
|
if j not in node_to_compute[i]: |
|
|
|
|
node_to_compute[i].append(j) |
|
|
|
|
|
|
|
|
|
def _mark_indice_equal(self, node1, dim1, node2, dim2): |
|
|
|
|
""" |
|
|
|
|
Mark 2 indice to be equal. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
idx1 (int): indice count. |
|
|
|
|
idx2 (int): indice count. |
|
|
|
|
""" |
|
|
|
|
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list) |
|
|
|
|
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list) |
|
|
|
|
# if node1_idx > node2_idx: |
|
|
|
|
# self._add_source(node2, dim2, node1, dim1) |
|
|
|
|
# else: |
|
|
|
|
# self._add_source(node1, dim1, node2, dim2) |
|
|
|
|
|
|
|
|
|
def _mark_computation(self, node, idx, dim): |
|
|
|
|
""" |
|
|
|
|
Mark some dims of node as computed. |
|
|
|
@ -273,19 +258,14 @@ class TraceIndice(object):
|
|
|
|
|
node_idx (int) |
|
|
|
|
""" |
|
|
|
|
if len(node.args) == 2: |
|
|
|
|
input_node, weight = node.args |
|
|
|
|
bias = None |
|
|
|
|
_, weight = node.args |
|
|
|
|
else: |
|
|
|
|
input_node, weight, bias = node.args |
|
|
|
|
_, weight, _ = node.args |
|
|
|
|
|
|
|
|
|
self._assign_indice_as_input(node, node_idx) |
|
|
|
|
self._inherit_indice(weight, 1, node, -1) |
|
|
|
|
|
|
|
|
|
self._mark_computation(node, node_idx, [-1]) |
|
|
|
|
self._mark_indice_equal(input_node, -1, weight, 0) |
|
|
|
|
|
|
|
|
|
if bias: |
|
|
|
|
self._mark_indice_equal(input_node, -1, bias, 0) |
|
|
|
|
|
|
|
|
|
def _assign_matmul_indice(self, node, node_idx): |
|
|
|
|
""" |
|
|
|
@ -306,7 +286,6 @@ class TraceIndice(object):
|
|
|
|
|
|
|
|
|
|
self._mark_computation_from_node(matmul_right, node, [-1, -2]) |
|
|
|
|
self._mark_computation(node, node_idx, [-1]) |
|
|
|
|
self._mark_indice_equal(matmul_left, -1, matmul_right, -2) |
|
|
|
|
|
|
|
|
|
def _assign_layernorm_indice(self, node, idx): |
|
|
|
|
""" |
|
|
|
@ -338,12 +317,6 @@ class TraceIndice(object):
|
|
|
|
|
nodes_in.append(node_in) |
|
|
|
|
self._mark_computation_from_node(node_in, node) |
|
|
|
|
assert len(nodes_in) <= 2 |
|
|
|
|
if len(nodes_in) == 2: |
|
|
|
|
node_in0_shape = get_node_shape(nodes_in[0]) |
|
|
|
|
node_in1_shape = get_node_shape(nodes_in[1]) |
|
|
|
|
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): |
|
|
|
|
if node_in0_shape[i] == node_in1_shape[i]: |
|
|
|
|
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i) |
|
|
|
|
|
|
|
|
|
def _assgin_no_change_indice(self, node, idx): |
|
|
|
|
self._assign_indice_as_input(node, idx) |
|
|
|
|