pull/2364/head
oahzxl 2 years ago
parent cb9817f75d
commit 1bb1f2ad89

@ -158,11 +158,11 @@ class SearchChunk(object):
end_trace = output_trace[end_idx] end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx] end_node = self.trace_indice.node_list[end_idx]
chunk_infos = [] chunk_infos = []
for end_dim, _ in enumerate(end_trace["idx"]): for end_dim, _ in enumerate(end_trace["indice"]):
if len(start_traces) > 1: if len(start_traces) > 1:
continue continue
for start_node, start_trace in start_traces.items(): for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["idx"]): for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1 # dim size cannot be 1
if ( if (
get_node_shape(end_node)[end_dim] == 1 get_node_shape(end_node)[end_dim] == 1

@ -19,12 +19,12 @@ class TraceIndice(object):
for n in self.node_list: for n in self.node_list:
if get_node_shape(n) != None: if get_node_shape(n) != None:
cur_trace = { cur_trace = {
"idx": [None for _ in range(len(get_node_shape(n)))], "indice": [None for _ in range(len(get_node_shape(n)))],
"compute": [[] for _ in range(len(get_node_shape(n)))], "compute": [[] for _ in range(len(get_node_shape(n)))],
"source": [{} for _ in range(len(get_node_shape(n)))], "source": [{} for _ in range(len(get_node_shape(n)))],
} }
else: else:
cur_trace = {"idx": [], "compute": [], "source": []} cur_trace = {"indice": [], "compute": [], "source": []}
indice_trace_list.append(cur_trace) indice_trace_list.append(cur_trace)
return indice_trace_list return indice_trace_list
@ -39,12 +39,12 @@ class TraceIndice(object):
return self.indice_count return self.indice_count
def _del_dim(self, idx, dim_idx): def _del_dim(self, idx, dim_idx):
self.indice_trace_list[idx]["idx"].pop(dim_idx) self.indice_trace_list[idx]["indice"].pop(dim_idx)
self.indice_trace_list[idx]["compute"].pop(dim_idx) self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.indice_trace_list[idx]["source"].pop(dim_idx) self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx, dim_idx): def _add_dim(self, node_idx, dim_idx):
self.indice_trace_list[node_idx]["idx"].insert(dim_idx, self._add_indice()) self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, []) self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {}) self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
@ -58,7 +58,7 @@ class TraceIndice(object):
node_to_dim = self._transform_indice(node_to, node_to_dim) node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from) node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to) node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy( node_to_trace["compute"][node_to_dim] = copy.deepcopy(
node_from_trace["compute"][node_from_dim] node_from_trace["compute"][node_from_dim]
) )
@ -181,7 +181,7 @@ class TraceIndice(object):
idx (list): idx of the node idx (list): idx of the node
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["idx"] return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node): def _find_compute_trace_from_node(self, node):
""" """
@ -195,7 +195,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"] return self.indice_trace_list[node_idx]["compute"]
def _assign_index_as_input(self, node, node_idx, input_node=None): def _assign_indice_as_input(self, node, node_idx, input_node=None):
""" """
Assign node's trace as its input node. Assign node's trace as its input node.
@ -206,10 +206,10 @@ class TraceIndice(object):
if input_node == None: if input_node == None:
input_node = node.args[0] input_node = node.args[0]
input_node_idx = find_idx_by_name(input_node.name, self.node_list) input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["idx"] input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.indice_trace_list[node_idx]["idx"] = new_idx_trace self.indice_trace_list[node_idx]["indice"] = new_idx_trace
self._inherit_all_computation(input_node, node) self._inherit_all_computation(input_node, node)
@ -225,7 +225,7 @@ class TraceIndice(object):
new_trace = [] new_trace = []
for _ in shape: for _ in shape:
new_trace.append(self._add_indice()) new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["idx"] = new_trace self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node, node_idx): def _assign_transpose_indice(self, node, node_idx):
""" """
@ -240,7 +240,7 @@ class TraceIndice(object):
input_node = node.args[0] input_node = node.args[0]
tranpose_dim = node.args[1:] tranpose_dim = node.args[1:]
self._assign_index_as_input(node, node_idx, input_node) self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
@ -257,7 +257,7 @@ class TraceIndice(object):
permute_dim = node.args[1:] permute_dim = node.args[1:]
input_node = node.args[0] input_node = node.args[0]
self._assign_index_as_input(node, node_idx, input_node) self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim): for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx) self._inherit_indice(input_node, d, node, idx)
@ -278,7 +278,7 @@ class TraceIndice(object):
else: else:
input_node, weight, bias = node.args input_node, weight, bias = node.args
self._assign_index_as_input(node, node_idx) self._assign_indice_as_input(node, node_idx)
self._inherit_indice(weight, 1, node, -1) self._inherit_indice(weight, 1, node, -1)
self._mark_computation(node, node_idx, [-1]) self._mark_computation(node, node_idx, [-1])
@ -301,7 +301,7 @@ class TraceIndice(object):
matmul_left, matmul_right = node.args matmul_left, matmul_right = node.args
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_index_as_input(node, node_idx, matmul_left) self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1) self._inherit_indice(matmul_right, -1, node, -1)
self._mark_computation_from_node(matmul_right, node, [-1, -2]) self._mark_computation_from_node(matmul_right, node, [-1, -2])
@ -318,7 +318,7 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, idx) self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1]) self._mark_computation(node, idx, [-1])
def _assign_elementwise_indice(self, node, idx): def _assign_elementwise_indice(self, node, idx):
@ -331,7 +331,7 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, idx) self._assign_indice_as_input(node, idx)
nodes_in = [] nodes_in = []
for node_in in node.args: for node_in in node.args:
if type(node_in) == type(node): if type(node_in) == type(node):
@ -346,7 +346,7 @@ class TraceIndice(object):
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i) self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
def _assgin_no_change_indice(self, node, idx): def _assgin_no_change_indice(self, node, idx):
self._assign_index_as_input(node, idx) self._assign_indice_as_input(node, idx)
for node_in in node.args: for node_in in node.args:
if type(node_in) == type(node): if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node) self._mark_computation_from_node(node_in, node)
@ -398,7 +398,7 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, idx) self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]]) self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node, node_idx): def _assign_unsqueeze_indice(self, node, node_idx):
@ -411,7 +411,7 @@ class TraceIndice(object):
node_idx (int) node_idx (int)
""" """
self._del_dim(node_idx, -1) self._del_dim(node_idx, -1)
self._assign_index_as_input(node, node_idx) self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, node.args[1]) self._add_dim(node_idx, node.args[1])
def _assign_dropout_indice(self, node, node_idx): def _assign_dropout_indice(self, node, node_idx):
@ -423,7 +423,7 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, node_idx) self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node, node_idx): def _assign_ones_like_indice(self, node, node_idx):
""" """
@ -497,7 +497,7 @@ class TraceIndice(object):
# get new index # get new index
origin_trace = self._find_indice_trace_from_node(origin_node) origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_index_as_input(node, node_idx, origin_node) self._assign_indice_as_input(node, node_idx, origin_node)
dim_from.reverse() dim_from.reverse()
for i in dim_from: for i in dim_from:
self._del_dim(node_idx, i) self._del_dim(node_idx, i)
@ -516,7 +516,7 @@ class TraceIndice(object):
view_dict = { view_dict = {
"idx_from": [origin_trace[i] for i in dim_from], "idx_from": [origin_trace[i] for i in dim_from],
"dim_from": dim_from, "dim_from": dim_from,
"idx_to": [self.indice_trace_list[node_idx]["idx"][i] for i in dim_to], "idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
"dim_to": dim_to, "dim_to": dim_to,
} }
self.indice_view_list[node] = view_dict self.indice_view_list[node] = view_dict
@ -528,9 +528,9 @@ class TraceIndice(object):
merge_to = min(idx) merge_to = min(idx)
merge_from = max(idx) merge_from = max(idx)
for trace in self.indice_trace_list: for trace in self.indice_trace_list:
if merge_from in trace["idx"]: if merge_from in trace["indice"]:
trace["idx"] = [ trace["indice"] = [
merge_to if i == merge_from else i for i in trace["idx"] merge_to if i == merge_from else i for i in trace["indice"]
] ]
def trace_index(self): def trace_index(self):

Loading…
Cancel
Save