|
|
|
@ -111,21 +111,6 @@ class TraceIndice(object):
|
|
|
|
|
if j not in node_to_compute[i]:
|
|
|
|
|
node_to_compute[i].append(j)
|
|
|
|
|
|
|
|
|
|
def _mark_indice_equal(self, node1, dim1, node2, dim2):
|
|
|
|
|
"""
|
|
|
|
|
Mark 2 indice to be equal.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
idx1 (int): indice count.
|
|
|
|
|
idx2 (int): indice count.
|
|
|
|
|
"""
|
|
|
|
|
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
|
|
|
|
|
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
|
|
|
|
|
# if node1_idx > node2_idx:
|
|
|
|
|
# self._add_source(node2, dim2, node1, dim1)
|
|
|
|
|
# else:
|
|
|
|
|
# self._add_source(node1, dim1, node2, dim2)
|
|
|
|
|
|
|
|
|
|
def _mark_computation(self, node, idx, dim):
|
|
|
|
|
"""
|
|
|
|
|
Mark some dims of node as computed.
|
|
|
|
@ -273,19 +258,14 @@ class TraceIndice(object):
|
|
|
|
|
node_idx (int)
|
|
|
|
|
"""
|
|
|
|
|
if len(node.args) == 2:
|
|
|
|
|
input_node, weight = node.args
|
|
|
|
|
bias = None
|
|
|
|
|
_, weight = node.args
|
|
|
|
|
else:
|
|
|
|
|
input_node, weight, bias = node.args
|
|
|
|
|
_, weight, _ = node.args
|
|
|
|
|
|
|
|
|
|
self._assign_indice_as_input(node, node_idx)
|
|
|
|
|
self._inherit_indice(weight, 1, node, -1)
|
|
|
|
|
|
|
|
|
|
self._mark_computation(node, node_idx, [-1])
|
|
|
|
|
self._mark_indice_equal(input_node, -1, weight, 0)
|
|
|
|
|
|
|
|
|
|
if bias:
|
|
|
|
|
self._mark_indice_equal(input_node, -1, bias, 0)
|
|
|
|
|
|
|
|
|
|
def _assign_matmul_indice(self, node, node_idx):
|
|
|
|
|
"""
|
|
|
|
@ -306,7 +286,6 @@ class TraceIndice(object):
|
|
|
|
|
|
|
|
|
|
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
|
|
|
|
self._mark_computation(node, node_idx, [-1])
|
|
|
|
|
self._mark_indice_equal(matmul_left, -1, matmul_right, -2)
|
|
|
|
|
|
|
|
|
|
def _assign_layernorm_indice(self, node, idx):
|
|
|
|
|
"""
|
|
|
|
@ -338,12 +317,6 @@ class TraceIndice(object):
|
|
|
|
|
nodes_in.append(node_in)
|
|
|
|
|
self._mark_computation_from_node(node_in, node)
|
|
|
|
|
assert len(nodes_in) <= 2
|
|
|
|
|
if len(nodes_in) == 2:
|
|
|
|
|
node_in0_shape = get_node_shape(nodes_in[0])
|
|
|
|
|
node_in1_shape = get_node_shape(nodes_in[1])
|
|
|
|
|
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1):
|
|
|
|
|
if node_in0_shape[i] == node_in1_shape[i]:
|
|
|
|
|
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
|
|
|
|
|
|
|
|
|
|
def _assgin_no_change_indice(self, node, idx):
|
|
|
|
|
self._assign_indice_as_input(node, idx)
|
|
|
|
|