|
|
|
@ -19,12 +19,12 @@ class TraceIndice(object):
|
|
|
|
|
for n in self.node_list: |
|
|
|
|
if get_node_shape(n) != None: |
|
|
|
|
cur_trace = { |
|
|
|
|
"idx": [None for _ in range(len(get_node_shape(n)))], |
|
|
|
|
"indice": [None for _ in range(len(get_node_shape(n)))], |
|
|
|
|
"compute": [[] for _ in range(len(get_node_shape(n)))], |
|
|
|
|
"source": [{} for _ in range(len(get_node_shape(n)))], |
|
|
|
|
} |
|
|
|
|
else: |
|
|
|
|
cur_trace = {"idx": [], "compute": [], "source": []} |
|
|
|
|
cur_trace = {"indice": [], "compute": [], "source": []} |
|
|
|
|
indice_trace_list.append(cur_trace) |
|
|
|
|
return indice_trace_list |
|
|
|
|
|
|
|
|
@ -39,12 +39,12 @@ class TraceIndice(object):
|
|
|
|
|
return self.indice_count |
|
|
|
|
|
|
|
|
|
def _del_dim(self, idx, dim_idx):
    """
    Remove one dimension from the trace of a node.

    The entry at ``dim_idx`` is popped from each per-dimension list of the
    trace ("indice", "compute" and "source").

    Args:
        idx (int): position of the node's trace in self.indice_trace_list
        dim_idx (int): dimension to remove; handed directly to list.pop
    """
    trace = self.indice_trace_list[idx]
    # Only the "indice" key is used for the per-dimension index list; the
    # legacy "idx" key is not present in every trace dict, so popping from
    # it would raise KeyError.
    for key in ("indice", "compute", "source"):
        trace[key].pop(dim_idx)
|
|
|
|
|
|
|
|
|
def _add_dim(self, node_idx, dim_idx):
    """
    Insert a new dimension into the trace of a node.

    A single new indice is requested from self._add_indice() and inserted
    at ``dim_idx``; matching empty "compute" and "source" entries are
    inserted at the same position.

    Args:
        node_idx (int): position of the node's trace in self.indice_trace_list
        dim_idx (int): position of the new dimension; handed directly to
            list.insert
    """
    trace = self.indice_trace_list[node_idx]
    # Request exactly one indice for the new dimension: a second call
    # would hand out another value for the same dimension.
    trace["indice"].insert(dim_idx, self._add_indice())
    trace["compute"].insert(dim_idx, [])
    trace["source"].insert(dim_idx, {})
|
|
|
|
|
|
|
|
@ -58,7 +58,7 @@ class TraceIndice(object):
|
|
|
|
|
node_to_dim = self._transform_indice(node_to, node_to_dim) |
|
|
|
|
node_from_trace = self._find_trace_from_node(node_from) |
|
|
|
|
node_to_trace = self._find_trace_from_node(node_to) |
|
|
|
|
node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] |
|
|
|
|
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] |
|
|
|
|
node_to_trace["compute"][node_to_dim] = copy.deepcopy( |
|
|
|
|
node_from_trace["compute"][node_from_dim] |
|
|
|
|
) |
|
|
|
@ -181,7 +181,7 @@ class TraceIndice(object):
|
|
|
|
|
idx (list): idx of the node |
|
|
|
|
""" |
|
|
|
|
node_idx = find_idx_by_name(node.name, self.node_list) |
|
|
|
|
return self.indice_trace_list[node_idx]["idx"] |
|
|
|
|
return self.indice_trace_list[node_idx]["indice"] |
|
|
|
|
|
|
|
|
|
def _find_compute_trace_from_node(self, node): |
|
|
|
|
""" |
|
|
|
@ -195,7 +195,7 @@ class TraceIndice(object):
|
|
|
|
|
node_idx = find_idx_by_name(node.name, self.node_list) |
|
|
|
|
return self.indice_trace_list[node_idx]["compute"] |
|
|
|
|
|
|
|
|
|
def _assign_index_as_input(self, node, node_idx, input_node=None): |
|
|
|
|
def _assign_indice_as_input(self, node, node_idx, input_node=None): |
|
|
|
|
""" |
|
|
|
|
Assign node's trace as its input node. |
|
|
|
|
|
|
|
|
@ -206,10 +206,10 @@ class TraceIndice(object):
|
|
|
|
|
if input_node == None: |
|
|
|
|
input_node = node.args[0] |
|
|
|
|
input_node_idx = find_idx_by_name(input_node.name, self.node_list) |
|
|
|
|
input_node_idx_trace = self.indice_trace_list[input_node_idx]["idx"] |
|
|
|
|
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"] |
|
|
|
|
|
|
|
|
|
new_idx_trace = copy.deepcopy(input_node_idx_trace) |
|
|
|
|
self.indice_trace_list[node_idx]["idx"] = new_idx_trace |
|
|
|
|
self.indice_trace_list[node_idx]["indice"] = new_idx_trace |
|
|
|
|
|
|
|
|
|
self._inherit_all_computation(input_node, node) |
|
|
|
|
|
|
|
|
@ -225,7 +225,7 @@ class TraceIndice(object):
|
|
|
|
|
new_trace = [] |
|
|
|
|
for _ in shape: |
|
|
|
|
new_trace.append(self._add_indice()) |
|
|
|
|
self.indice_trace_list[node_idx]["idx"] = new_trace |
|
|
|
|
self.indice_trace_list[node_idx]["indice"] = new_trace |
|
|
|
|
|
|
|
|
|
def _assign_transpose_indice(self, node, node_idx): |
|
|
|
|
""" |
|
|
|
@ -240,7 +240,7 @@ class TraceIndice(object):
|
|
|
|
|
input_node = node.args[0] |
|
|
|
|
tranpose_dim = node.args[1:] |
|
|
|
|
|
|
|
|
|
self._assign_index_as_input(node, node_idx, input_node) |
|
|
|
|
self._assign_indice_as_input(node, node_idx, input_node) |
|
|
|
|
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) |
|
|
|
|
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) |
|
|
|
|
|
|
|
|
@ -257,7 +257,7 @@ class TraceIndice(object):
|
|
|
|
|
permute_dim = node.args[1:] |
|
|
|
|
input_node = node.args[0] |
|
|
|
|
|
|
|
|
|
self._assign_index_as_input(node, node_idx, input_node) |
|
|
|
|
self._assign_indice_as_input(node, node_idx, input_node) |
|
|
|
|
for idx, d in enumerate(permute_dim): |
|
|
|
|
self._inherit_indice(input_node, d, node, idx) |
|
|
|
|
|
|
|
|
@ -278,7 +278,7 @@ class TraceIndice(object):
|
|
|
|
|
else: |
|
|
|
|
input_node, weight, bias = node.args |
|
|
|
|
|
|
|
|
|
self._assign_index_as_input(node, node_idx) |
|
|
|
|
self._assign_indice_as_input(node, node_idx) |
|
|
|
|
self._inherit_indice(weight, 1, node, -1) |
|
|
|
|
|
|
|
|
|
self._mark_computation(node, node_idx, [-1]) |
|
|
|
@ -301,7 +301,7 @@ class TraceIndice(object):
|
|
|
|
|
matmul_left, matmul_right = node.args |
|
|
|
|
|
|
|
|
|
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) |
|
|
|
|
self._assign_index_as_input(node, node_idx, matmul_left) |
|
|
|
|
self._assign_indice_as_input(node, node_idx, matmul_left) |
|
|
|
|
self._inherit_indice(matmul_right, -1, node, -1) |
|
|
|
|
|
|
|
|
|
self._mark_computation_from_node(matmul_right, node, [-1, -2]) |
|
|
|
@ -318,7 +318,7 @@ class TraceIndice(object):
|
|
|
|
|
node (node) |
|
|
|
|
node_idx (int) |
|
|
|
|
""" |
|
|
|
|
self._assign_index_as_input(node, idx) |
|
|
|
|
self._assign_indice_as_input(node, idx) |
|
|
|
|
self._mark_computation(node, idx, [-1]) |
|
|
|
|
|
|
|
|
|
def _assign_elementwise_indice(self, node, idx): |
|
|
|
@ -331,7 +331,7 @@ class TraceIndice(object):
|
|
|
|
|
node (node) |
|
|
|
|
node_idx (int) |
|
|
|
|
""" |
|
|
|
|
self._assign_index_as_input(node, idx) |
|
|
|
|
self._assign_indice_as_input(node, idx) |
|
|
|
|
nodes_in = [] |
|
|
|
|
for node_in in node.args: |
|
|
|
|
if type(node_in) == type(node): |
|
|
|
@ -346,7 +346,7 @@ class TraceIndice(object):
|
|
|
|
|
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i) |
|
|
|
|
|
|
|
|
|
def _assgin_no_change_indice(self, node, idx): |
|
|
|
|
self._assign_index_as_input(node, idx) |
|
|
|
|
self._assign_indice_as_input(node, idx) |
|
|
|
|
for node_in in node.args: |
|
|
|
|
if type(node_in) == type(node): |
|
|
|
|
self._mark_computation_from_node(node_in, node) |
|
|
|
@ -398,7 +398,7 @@ class TraceIndice(object):
|
|
|
|
|
node (node) |
|
|
|
|
node_idx (int) |
|
|
|
|
""" |
|
|
|
|
self._assign_index_as_input(node, idx) |
|
|
|
|
self._assign_indice_as_input(node, idx) |
|
|
|
|
self._mark_computation(node, idx, [node.kwargs["dim"]]) |
|
|
|
|
|
|
|
|
|
def _assign_unsqueeze_indice(self, node, node_idx): |
|
|
|
@ -411,7 +411,7 @@ class TraceIndice(object):
|
|
|
|
|
node_idx (int) |
|
|
|
|
""" |
|
|
|
|
self._del_dim(node_idx, -1) |
|
|
|
|
self._assign_index_as_input(node, node_idx) |
|
|
|
|
self._assign_indice_as_input(node, node_idx) |
|
|
|
|
self._add_dim(node_idx, node.args[1]) |
|
|
|
|
|
|
|
|
|
def _assign_dropout_indice(self, node, node_idx): |
|
|
|
@ -423,7 +423,7 @@ class TraceIndice(object):
|
|
|
|
|
node (node) |
|
|
|
|
node_idx (int) |
|
|
|
|
""" |
|
|
|
|
self._assign_index_as_input(node, node_idx) |
|
|
|
|
self._assign_indice_as_input(node, node_idx) |
|
|
|
|
|
|
|
|
|
def _assign_ones_like_indice(self, node, node_idx): |
|
|
|
|
""" |
|
|
|
@ -497,7 +497,7 @@ class TraceIndice(object):
|
|
|
|
|
|
|
|
|
|
# get new index |
|
|
|
|
origin_trace = self._find_indice_trace_from_node(origin_node) |
|
|
|
|
self._assign_index_as_input(node, node_idx, origin_node) |
|
|
|
|
self._assign_indice_as_input(node, node_idx, origin_node) |
|
|
|
|
dim_from.reverse() |
|
|
|
|
for i in dim_from: |
|
|
|
|
self._del_dim(node_idx, i) |
|
|
|
@ -516,7 +516,7 @@ class TraceIndice(object):
|
|
|
|
|
view_dict = { |
|
|
|
|
"idx_from": [origin_trace[i] for i in dim_from], |
|
|
|
|
"dim_from": dim_from, |
|
|
|
|
"idx_to": [self.indice_trace_list[node_idx]["idx"][i] for i in dim_to], |
|
|
|
|
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to], |
|
|
|
|
"dim_to": dim_to, |
|
|
|
|
} |
|
|
|
|
self.indice_view_list[node] = view_dict |
|
|
|
@ -528,9 +528,9 @@ class TraceIndice(object):
|
|
|
|
|
merge_to = min(idx) |
|
|
|
|
merge_from = max(idx) |
|
|
|
|
for trace in self.indice_trace_list: |
|
|
|
|
if merge_from in trace["idx"]: |
|
|
|
|
trace["idx"] = [ |
|
|
|
|
merge_to if i == merge_from else i for i in trace["idx"] |
|
|
|
|
if merge_from in trace["indice"]: |
|
|
|
|
trace["indice"] = [ |
|
|
|
|
merge_to if i == merge_from else i for i in trace["indice"] |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
def trace_index(self): |
|
|
|
|