|
|
@ -19,12 +19,12 @@ class TraceIndice(object):
|
|
|
|
for n in self.node_list:
|
|
|
|
for n in self.node_list:
|
|
|
|
if get_node_shape(n) != None:
|
|
|
|
if get_node_shape(n) != None:
|
|
|
|
cur_trace = {
|
|
|
|
cur_trace = {
|
|
|
|
"idx": [None for _ in range(len(get_node_shape(n)))],
|
|
|
|
"indice": [None for _ in range(len(get_node_shape(n)))],
|
|
|
|
"compute": [[] for _ in range(len(get_node_shape(n)))],
|
|
|
|
"compute": [[] for _ in range(len(get_node_shape(n)))],
|
|
|
|
"source": [{} for _ in range(len(get_node_shape(n)))],
|
|
|
|
"source": [{} for _ in range(len(get_node_shape(n)))],
|
|
|
|
}
|
|
|
|
}
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
cur_trace = {"idx": [], "compute": [], "source": []}
|
|
|
|
cur_trace = {"indice": [], "compute": [], "source": []}
|
|
|
|
indice_trace_list.append(cur_trace)
|
|
|
|
indice_trace_list.append(cur_trace)
|
|
|
|
return indice_trace_list
|
|
|
|
return indice_trace_list
|
|
|
|
|
|
|
|
|
|
|
@ -39,12 +39,12 @@ class TraceIndice(object):
|
|
|
|
return self.indice_count
|
|
|
|
return self.indice_count
|
|
|
|
|
|
|
|
|
|
|
|
def _del_dim(self, idx, dim_idx):
|
|
|
|
def _del_dim(self, idx, dim_idx):
|
|
|
|
self.indice_trace_list[idx]["idx"].pop(dim_idx)
|
|
|
|
self.indice_trace_list[idx]["indice"].pop(dim_idx)
|
|
|
|
self.indice_trace_list[idx]["compute"].pop(dim_idx)
|
|
|
|
self.indice_trace_list[idx]["compute"].pop(dim_idx)
|
|
|
|
self.indice_trace_list[idx]["source"].pop(dim_idx)
|
|
|
|
self.indice_trace_list[idx]["source"].pop(dim_idx)
|
|
|
|
|
|
|
|
|
|
|
|
def _add_dim(self, node_idx, dim_idx):
|
|
|
|
def _add_dim(self, node_idx, dim_idx):
|
|
|
|
self.indice_trace_list[node_idx]["idx"].insert(dim_idx, self._add_indice())
|
|
|
|
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
|
|
|
|
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
|
|
|
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
|
|
|
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
|
|
|
|
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
|
|
|
|
|
|
|
|
|
|
|
@ -58,7 +58,7 @@ class TraceIndice(object):
|
|
|
|
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
|
|
|
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
|
|
|
node_from_trace = self._find_trace_from_node(node_from)
|
|
|
|
node_from_trace = self._find_trace_from_node(node_from)
|
|
|
|
node_to_trace = self._find_trace_from_node(node_to)
|
|
|
|
node_to_trace = self._find_trace_from_node(node_to)
|
|
|
|
node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim]
|
|
|
|
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
|
|
|
|
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
|
|
|
|
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
|
|
|
|
node_from_trace["compute"][node_from_dim]
|
|
|
|
node_from_trace["compute"][node_from_dim]
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -181,7 +181,7 @@ class TraceIndice(object):
|
|
|
|
idx (list): idx of the node
|
|
|
|
idx (list): idx of the node
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
|
|
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
|
|
|
return self.indice_trace_list[node_idx]["idx"]
|
|
|
|
return self.indice_trace_list[node_idx]["indice"]
|
|
|
|
|
|
|
|
|
|
|
|
def _find_compute_trace_from_node(self, node):
|
|
|
|
def _find_compute_trace_from_node(self, node):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -195,7 +195,7 @@ class TraceIndice(object):
|
|
|
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
|
|
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
|
|
|
return self.indice_trace_list[node_idx]["compute"]
|
|
|
|
return self.indice_trace_list[node_idx]["compute"]
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_index_as_input(self, node, node_idx, input_node=None):
|
|
|
|
def _assign_indice_as_input(self, node, node_idx, input_node=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Assign node's trace as its input node.
|
|
|
|
Assign node's trace as its input node.
|
|
|
|
|
|
|
|
|
|
|
@ -206,10 +206,10 @@ class TraceIndice(object):
|
|
|
|
if input_node == None:
|
|
|
|
if input_node == None:
|
|
|
|
input_node = node.args[0]
|
|
|
|
input_node = node.args[0]
|
|
|
|
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
|
|
|
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
|
|
|
input_node_idx_trace = self.indice_trace_list[input_node_idx]["idx"]
|
|
|
|
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
|
|
|
|
|
|
|
|
|
|
|
|
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
|
|
|
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
|
|
|
self.indice_trace_list[node_idx]["idx"] = new_idx_trace
|
|
|
|
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
|
|
|
|
|
|
|
|
|
|
|
|
self._inherit_all_computation(input_node, node)
|
|
|
|
self._inherit_all_computation(input_node, node)
|
|
|
|
|
|
|
|
|
|
|
@ -225,7 +225,7 @@ class TraceIndice(object):
|
|
|
|
new_trace = []
|
|
|
|
new_trace = []
|
|
|
|
for _ in shape:
|
|
|
|
for _ in shape:
|
|
|
|
new_trace.append(self._add_indice())
|
|
|
|
new_trace.append(self._add_indice())
|
|
|
|
self.indice_trace_list[node_idx]["idx"] = new_trace
|
|
|
|
self.indice_trace_list[node_idx]["indice"] = new_trace
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_transpose_indice(self, node, node_idx):
|
|
|
|
def _assign_transpose_indice(self, node, node_idx):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -240,7 +240,7 @@ class TraceIndice(object):
|
|
|
|
input_node = node.args[0]
|
|
|
|
input_node = node.args[0]
|
|
|
|
tranpose_dim = node.args[1:]
|
|
|
|
tranpose_dim = node.args[1:]
|
|
|
|
|
|
|
|
|
|
|
|
self._assign_index_as_input(node, node_idx, input_node)
|
|
|
|
self._assign_indice_as_input(node, node_idx, input_node)
|
|
|
|
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
|
|
|
|
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
|
|
|
|
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
|
|
|
|
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
|
|
|
|
|
|
|
|
|
|
|
@ -257,7 +257,7 @@ class TraceIndice(object):
|
|
|
|
permute_dim = node.args[1:]
|
|
|
|
permute_dim = node.args[1:]
|
|
|
|
input_node = node.args[0]
|
|
|
|
input_node = node.args[0]
|
|
|
|
|
|
|
|
|
|
|
|
self._assign_index_as_input(node, node_idx, input_node)
|
|
|
|
self._assign_indice_as_input(node, node_idx, input_node)
|
|
|
|
for idx, d in enumerate(permute_dim):
|
|
|
|
for idx, d in enumerate(permute_dim):
|
|
|
|
self._inherit_indice(input_node, d, node, idx)
|
|
|
|
self._inherit_indice(input_node, d, node, idx)
|
|
|
|
|
|
|
|
|
|
|
@ -278,7 +278,7 @@ class TraceIndice(object):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
input_node, weight, bias = node.args
|
|
|
|
input_node, weight, bias = node.args
|
|
|
|
|
|
|
|
|
|
|
|
self._assign_index_as_input(node, node_idx)
|
|
|
|
self._assign_indice_as_input(node, node_idx)
|
|
|
|
self._inherit_indice(weight, 1, node, -1)
|
|
|
|
self._inherit_indice(weight, 1, node, -1)
|
|
|
|
|
|
|
|
|
|
|
|
self._mark_computation(node, node_idx, [-1])
|
|
|
|
self._mark_computation(node, node_idx, [-1])
|
|
|
@ -301,7 +301,7 @@ class TraceIndice(object):
|
|
|
|
matmul_left, matmul_right = node.args
|
|
|
|
matmul_left, matmul_right = node.args
|
|
|
|
|
|
|
|
|
|
|
|
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
|
|
|
|
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
|
|
|
|
self._assign_index_as_input(node, node_idx, matmul_left)
|
|
|
|
self._assign_indice_as_input(node, node_idx, matmul_left)
|
|
|
|
self._inherit_indice(matmul_right, -1, node, -1)
|
|
|
|
self._inherit_indice(matmul_right, -1, node, -1)
|
|
|
|
|
|
|
|
|
|
|
|
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
|
|
|
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
|
|
@ -318,7 +318,7 @@ class TraceIndice(object):
|
|
|
|
node (node)
|
|
|
|
node (node)
|
|
|
|
node_idx (int)
|
|
|
|
node_idx (int)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._assign_index_as_input(node, idx)
|
|
|
|
self._assign_indice_as_input(node, idx)
|
|
|
|
self._mark_computation(node, idx, [-1])
|
|
|
|
self._mark_computation(node, idx, [-1])
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_elementwise_indice(self, node, idx):
|
|
|
|
def _assign_elementwise_indice(self, node, idx):
|
|
|
@ -331,7 +331,7 @@ class TraceIndice(object):
|
|
|
|
node (node)
|
|
|
|
node (node)
|
|
|
|
node_idx (int)
|
|
|
|
node_idx (int)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._assign_index_as_input(node, idx)
|
|
|
|
self._assign_indice_as_input(node, idx)
|
|
|
|
nodes_in = []
|
|
|
|
nodes_in = []
|
|
|
|
for node_in in node.args:
|
|
|
|
for node_in in node.args:
|
|
|
|
if type(node_in) == type(node):
|
|
|
|
if type(node_in) == type(node):
|
|
|
@ -346,7 +346,7 @@ class TraceIndice(object):
|
|
|
|
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
|
|
|
|
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
|
|
|
|
|
|
|
|
|
|
|
|
def _assgin_no_change_indice(self, node, idx):
|
|
|
|
def _assgin_no_change_indice(self, node, idx):
|
|
|
|
self._assign_index_as_input(node, idx)
|
|
|
|
self._assign_indice_as_input(node, idx)
|
|
|
|
for node_in in node.args:
|
|
|
|
for node_in in node.args:
|
|
|
|
if type(node_in) == type(node):
|
|
|
|
if type(node_in) == type(node):
|
|
|
|
self._mark_computation_from_node(node_in, node)
|
|
|
|
self._mark_computation_from_node(node_in, node)
|
|
|
@ -398,7 +398,7 @@ class TraceIndice(object):
|
|
|
|
node (node)
|
|
|
|
node (node)
|
|
|
|
node_idx (int)
|
|
|
|
node_idx (int)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._assign_index_as_input(node, idx)
|
|
|
|
self._assign_indice_as_input(node, idx)
|
|
|
|
self._mark_computation(node, idx, [node.kwargs["dim"]])
|
|
|
|
self._mark_computation(node, idx, [node.kwargs["dim"]])
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_unsqueeze_indice(self, node, node_idx):
|
|
|
|
def _assign_unsqueeze_indice(self, node, node_idx):
|
|
|
@ -411,7 +411,7 @@ class TraceIndice(object):
|
|
|
|
node_idx (int)
|
|
|
|
node_idx (int)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._del_dim(node_idx, -1)
|
|
|
|
self._del_dim(node_idx, -1)
|
|
|
|
self._assign_index_as_input(node, node_idx)
|
|
|
|
self._assign_indice_as_input(node, node_idx)
|
|
|
|
self._add_dim(node_idx, node.args[1])
|
|
|
|
self._add_dim(node_idx, node.args[1])
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_dropout_indice(self, node, node_idx):
|
|
|
|
def _assign_dropout_indice(self, node, node_idx):
|
|
|
@ -423,7 +423,7 @@ class TraceIndice(object):
|
|
|
|
node (node)
|
|
|
|
node (node)
|
|
|
|
node_idx (int)
|
|
|
|
node_idx (int)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._assign_index_as_input(node, node_idx)
|
|
|
|
self._assign_indice_as_input(node, node_idx)
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_ones_like_indice(self, node, node_idx):
|
|
|
|
def _assign_ones_like_indice(self, node, node_idx):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -497,7 +497,7 @@ class TraceIndice(object):
|
|
|
|
|
|
|
|
|
|
|
|
# get new index
|
|
|
|
# get new index
|
|
|
|
origin_trace = self._find_indice_trace_from_node(origin_node)
|
|
|
|
origin_trace = self._find_indice_trace_from_node(origin_node)
|
|
|
|
self._assign_index_as_input(node, node_idx, origin_node)
|
|
|
|
self._assign_indice_as_input(node, node_idx, origin_node)
|
|
|
|
dim_from.reverse()
|
|
|
|
dim_from.reverse()
|
|
|
|
for i in dim_from:
|
|
|
|
for i in dim_from:
|
|
|
|
self._del_dim(node_idx, i)
|
|
|
|
self._del_dim(node_idx, i)
|
|
|
@ -516,7 +516,7 @@ class TraceIndice(object):
|
|
|
|
view_dict = {
|
|
|
|
view_dict = {
|
|
|
|
"idx_from": [origin_trace[i] for i in dim_from],
|
|
|
|
"idx_from": [origin_trace[i] for i in dim_from],
|
|
|
|
"dim_from": dim_from,
|
|
|
|
"dim_from": dim_from,
|
|
|
|
"idx_to": [self.indice_trace_list[node_idx]["idx"][i] for i in dim_to],
|
|
|
|
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
|
|
|
|
"dim_to": dim_to,
|
|
|
|
"dim_to": dim_to,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
self.indice_view_list[node] = view_dict
|
|
|
|
self.indice_view_list[node] = view_dict
|
|
|
@ -528,9 +528,9 @@ class TraceIndice(object):
|
|
|
|
merge_to = min(idx)
|
|
|
|
merge_to = min(idx)
|
|
|
|
merge_from = max(idx)
|
|
|
|
merge_from = max(idx)
|
|
|
|
for trace in self.indice_trace_list:
|
|
|
|
for trace in self.indice_trace_list:
|
|
|
|
if merge_from in trace["idx"]:
|
|
|
|
if merge_from in trace["indice"]:
|
|
|
|
trace["idx"] = [
|
|
|
|
trace["indice"] = [
|
|
|
|
merge_to if i == merge_from else i for i in trace["idx"]
|
|
|
|
merge_to if i == merge_from else i for i in trace["indice"]
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
def trace_index(self):
|
|
|
|
def trace_index(self):
|
|
|
|