|
|
|
@ -461,7 +461,7 @@ class TraceIndice(object):
|
|
|
|
|
nodes_in.append(node_in)
|
|
|
|
|
self._inherit_more_indice_from_node_with_exclude(node_in, node)
|
|
|
|
|
|
|
|
|
|
def _assgin_no_change_indice(self, node, idx):
|
|
|
|
|
def _assign_no_change_indice(self, node, idx):
|
|
|
|
|
self._assign_indice_as_input(node, idx)
|
|
|
|
|
for node_in in node.args:
|
|
|
|
|
if type(node_in) == type(node):
|
|
|
|
@ -792,7 +792,7 @@ class TraceIndice(object):
|
|
|
|
|
self._add_dim(node_idx, i)
|
|
|
|
|
dim_from.reverse()
|
|
|
|
|
|
|
|
|
|
# inheirt indice from current node
|
|
|
|
|
# inherit indice from current node
|
|
|
|
|
if len(dim_from) != 0 and len(dim_to) != 0:
|
|
|
|
|
if dim_diff == 1:
|
|
|
|
|
if origin_shape[dim_from[0]] == 1:
|
|
|
|
@ -852,7 +852,7 @@ class TraceIndice(object):
|
|
|
|
|
elif "split" == node_name:
|
|
|
|
|
self._assign_split_indice(node, idx)
|
|
|
|
|
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
|
|
|
|
|
self._assgin_no_change_indice(node, idx)
|
|
|
|
|
self._assign_no_change_indice(node, idx)
|
|
|
|
|
elif "new_ones" == node_name:
|
|
|
|
|
self._assign_all_indice(node, idx)
|
|
|
|
|
elif "flatten" == node_name:
|
|
|
|
@ -914,7 +914,7 @@ class TraceIndice(object):
|
|
|
|
|
elif "conv2d" == node_name:
|
|
|
|
|
self._assign_conv2d_indice(node, idx)
|
|
|
|
|
elif "identity" == node_name:
|
|
|
|
|
self._assgin_no_change_indice(node, idx)
|
|
|
|
|
self._assign_no_change_indice(node, idx)
|
|
|
|
|
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
|
|
|
|
|
self._assign_elementwise_indice(node, idx)
|
|
|
|
|
else:
|
|
|
|
|