rename function from index to indice

pull/2364/head
oahzxl 2023-01-09 17:34:30 +08:00
parent 0ea903b94e
commit cb9817f75d
4 changed files with 91 additions and 91 deletions

View File

@ -6,7 +6,7 @@ class ReorderGraph(object):
def __init__(self, trace_indice: TraceIndice) -> None: def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_indice = trace_indice self.trace_indice = trace_indice
self.all_reorder_map = { self.all_reorder_map = {
i: i for i in range(len(self.trace_indice.idx_trace_list)) i: i for i in range(len(self.trace_indice.indice_trace_list))
} }
def _get_reorder_map(self, chunk_info): def _get_reorder_map(self, chunk_info):
@ -60,18 +60,18 @@ class ReorderGraph(object):
def _reorder_idx_trace(self, reorder_map): def _reorder_idx_trace(self, reorder_map):
# reorder list # reorder list
new_idx_trace_list = [None for _ in range(len(self.trace_indice.idx_trace_list))] new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_indice.idx_trace_list[old_idx] new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
self.trace_indice.idx_trace_list = new_idx_trace_list self.trace_indice.indice_trace_list = new_idx_trace_list
# update compute # update compute
for idx_trace in self.trace_indice.idx_trace_list: for idx_trace in self.trace_indice.indice_trace_list:
compute = idx_trace["compute"] compute = idx_trace["compute"]
for dim_compute in compute: for dim_compute in compute:
for idx, i in enumerate(dim_compute): for idx, i in enumerate(dim_compute):
dim_compute[idx] = reorder_map[i] dim_compute[idx] = reorder_map[i]
# update source # update source
for idx_trace in self.trace_indice.idx_trace_list: for idx_trace in self.trace_indice.indice_trace_list:
source = idx_trace["source"] source = idx_trace["source"]
for dim_idx, dim_source in enumerate(source): for dim_idx, dim_source in enumerate(source):
new_dim_source = {} new_dim_source = {}

View File

@ -205,7 +205,7 @@ class SearchChunk(object):
possible_chunk_region (List) possible_chunk_region (List)
""" """
possible_chunk_region = [] possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.idx_trace_list) output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list): for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {} cur_trace = {}

View File

@ -406,7 +406,7 @@ class TraceFlow(object):
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]: for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]): if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:] reshape_args = node.args[1:]
reshape_log = self.trace_indice.idx_view_list[node] reshape_log = self.trace_indice.indice_view_list[node]
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
reshape_size[node.name] = {} reshape_size[node.name] = {}
for reshape_arg_dim, reshape_arg in enumerate(reshape_args): for reshape_arg_dim, reshape_arg in enumerate(reshape_args):

View File

@ -9,13 +9,13 @@ from .utils import (
class TraceIndice(object): class TraceIndice(object):
def __init__(self, node_list) -> None: def __init__(self, node_list) -> None:
self.node_list = node_list self.node_list = node_list
self.idx_trace_list = self._init_idx_trace_list() self.indice_trace_list = self._init_indice_trace_list()
self.idx_trace_equal = [] self.indice_trace_equal = []
self.idx_view_list = {} self.indice_view_list = {}
self.idx_count = -1 self.indice_count = -1
def _init_idx_trace_list(self): def _init_indice_trace_list(self):
idx_trace_list = [] indice_trace_list = []
for n in self.node_list: for n in self.node_list:
if get_node_shape(n) != None: if get_node_shape(n) != None:
cur_trace = { cur_trace = {
@ -25,37 +25,37 @@ class TraceIndice(object):
} }
else: else:
cur_trace = {"idx": [], "compute": [], "source": []} cur_trace = {"idx": [], "compute": [], "source": []}
idx_trace_list.append(cur_trace) indice_trace_list.append(cur_trace)
return idx_trace_list return indice_trace_list
def _add_index(self): def _add_indice(self):
""" """
Update the count and return it. To record the idx number. Update the count and return it. To record the idx number.
Returns: Returns:
idx_count: int idx_count: int
""" """
self.idx_count += 1 self.indice_count += 1
return self.idx_count return self.indice_count
def _del_dim(self, idx, dim_idx): def _del_dim(self, idx, dim_idx):
self.idx_trace_list[idx]["idx"].pop(dim_idx) self.indice_trace_list[idx]["idx"].pop(dim_idx)
self.idx_trace_list[idx]["compute"].pop(dim_idx) self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.idx_trace_list[idx]["source"].pop(dim_idx) self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx, dim_idx): def _add_dim(self, node_idx, dim_idx):
self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) self.indice_trace_list[node_idx]["idx"].insert(dim_idx, self._add_indice())
self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
def _transform_index(self, node, node_dim): def _transform_indice(self, node, node_dim):
node_idx = self._find_idx_trace_from_node(node) node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx))) dims = list(range(len(node_idx)))
return dims[node_dim] return dims[node_dim]
def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_dim = self._transform_index(node_from, node_from_dim) node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_index(node_to, node_to_dim) node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from) node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to) node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim]
@ -73,9 +73,9 @@ class TraceIndice(object):
node_to_compute[i] = copy.deepcopy(node_from_compute[i]) node_to_compute[i] = copy.deepcopy(node_from_compute[i])
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
node_from_dim = self._transform_index(node_from, node_from_dim) node_from_dim = self._transform_indice(node_from, node_from_dim)
node_from_trace_source = self._find_source_trace_from_node(node_from) node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_index(node_to, node_to_dim) node_to_dim = self._transform_indice(node_to, node_to_dim)
node_to_trace_source = self._find_source_trace_from_node(node_to) node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = find_idx_by_name(node_from.name, self.node_list) node_from_idx = find_idx_by_name(node_from.name, self.node_list)
if init: if init:
@ -99,19 +99,19 @@ class TraceIndice(object):
if exclude == None: if exclude == None:
exclude = [] exclude = []
else: else:
exclude = [self._transform_index(node_to, i) for i in exclude] exclude = [self._transform_indice(node_to, i) for i in exclude]
node_from_compute = self._find_compute_trace_from_node(node_from) node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to) node_to_compute = self._find_compute_trace_from_node(node_to)
# assert len(node_from_compute) == len(node_to_compute) # assert len(node_from_compute) == len(node_to_compute)
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
if self._transform_index(node_to, i) in exclude: if self._transform_indice(node_to, i) in exclude:
continue continue
self._add_source(node_from, i, node_to, i) self._add_source(node_from, i, node_to, i)
for j in node_from_compute[i]: for j in node_from_compute[i]:
if j not in node_to_compute[i]: if j not in node_to_compute[i]:
node_to_compute[i].append(j) node_to_compute[i].append(j)
def _mark_idx_equal(self, node1, dim1, node2, dim2): def _mark_indice_equal(self, node1, dim1, node2, dim2):
""" """
Mark 2 index to be equal. Mark 2 index to be equal.
@ -140,8 +140,8 @@ class TraceIndice(object):
dims = list(range(len(get_node_shape(node)))) dims = list(range(len(get_node_shape(node))))
for d in dim: for d in dim:
cur_dim = dims[d] cur_dim = dims[d]
if idx not in self.idx_trace_list[idx]["compute"][cur_dim]: if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
self.idx_trace_list[idx]["compute"][cur_dim].append(idx) self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
def _find_trace_from_node(self, node): def _find_trace_from_node(self, node):
""" """
@ -154,7 +154,7 @@ class TraceIndice(object):
compute (list): computed idx of the node. compute (list): computed idx of the node.
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
node_dict = self.idx_trace_list[node_idx] node_dict = self.indice_trace_list[node_idx]
return node_dict return node_dict
def _find_source_trace_from_node(self, node): def _find_source_trace_from_node(self, node):
@ -168,10 +168,10 @@ class TraceIndice(object):
compute (list): computed idx of the node. compute (list): computed idx of the node.
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
node_dict = self.idx_trace_list[node_idx] node_dict = self.indice_trace_list[node_idx]
return node_dict["source"] return node_dict["source"]
def _find_idx_trace_from_node(self, node): def _find_indice_trace_from_node(self, node):
""" """
Find node idx trace by the node. Find node idx trace by the node.
@ -181,7 +181,7 @@ class TraceIndice(object):
idx (list): idx of the node idx (list): idx of the node
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
return self.idx_trace_list[node_idx]["idx"] return self.indice_trace_list[node_idx]["idx"]
def _find_compute_trace_from_node(self, node): def _find_compute_trace_from_node(self, node):
""" """
@ -193,7 +193,7 @@ class TraceIndice(object):
compute (list): computed idx of the node. compute (list): computed idx of the node.
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
return self.idx_trace_list[node_idx]["compute"] return self.indice_trace_list[node_idx]["compute"]
def _assign_index_as_input(self, node, node_idx, input_node=None): def _assign_index_as_input(self, node, node_idx, input_node=None):
""" """
@ -206,14 +206,14 @@ class TraceIndice(object):
if input_node == None: if input_node == None:
input_node = node.args[0] input_node = node.args[0]
input_node_idx = find_idx_by_name(input_node.name, self.node_list) input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] input_node_idx_trace = self.indice_trace_list[input_node_idx]["idx"]
new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.idx_trace_list[node_idx]["idx"] = new_idx_trace self.indice_trace_list[node_idx]["idx"] = new_idx_trace
self._inherit_all_computation(input_node, node) self._inherit_all_computation(input_node, node)
def _assign_all_index(self, node, node_idx): def _assign_all_indice(self, node, node_idx):
""" """
Add new index for all node's dims. Add new index for all node's dims.
@ -224,10 +224,10 @@ class TraceIndice(object):
shape = node.meta["tensor_meta"].shape shape = node.meta["tensor_meta"].shape
new_trace = [] new_trace = []
for _ in shape: for _ in shape:
new_trace.append(self._add_index()) new_trace.append(self._add_indice())
self.idx_trace_list[node_idx]["idx"] = new_trace self.indice_trace_list[node_idx]["idx"] = new_trace
def _assign_transpose_index(self, node, node_idx): def _assign_transpose_indice(self, node, node_idx):
""" """
Assign index for transpose op. Assign index for transpose op.
1. swap input's dim according to transpose args 1. swap input's dim according to transpose args
@ -241,10 +241,10 @@ class TraceIndice(object):
tranpose_dim = node.args[1:] tranpose_dim = node.args[1:]
self._assign_index_as_input(node, node_idx, input_node) self._assign_index_as_input(node, node_idx, input_node)
self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_index(self, node, node_idx): def _assign_permute_indice(self, node, node_idx):
""" """
Assign index for permute op. Assign index for permute op.
1. swap input's dim according to permute args 1. swap input's dim according to permute args
@ -259,9 +259,9 @@ class TraceIndice(object):
self._assign_index_as_input(node, node_idx, input_node) self._assign_index_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim): for idx, d in enumerate(permute_dim):
self._inherit_index(input_node, d, node, idx) self._inherit_indice(input_node, d, node, idx)
def _assign_linear_index(self, node, node_idx): def _assign_linear_indice(self, node, node_idx):
""" """
Assign index for linear op. Assign index for linear op.
1. copy trace from input node and change last index accroding to weight 1. copy trace from input node and change last index accroding to weight
@ -279,15 +279,15 @@ class TraceIndice(object):
input_node, weight, bias = node.args input_node, weight, bias = node.args
self._assign_index_as_input(node, node_idx) self._assign_index_as_input(node, node_idx)
self._inherit_index(weight, 1, node, -1) self._inherit_indice(weight, 1, node, -1)
self._mark_computation(node, node_idx, [-1]) self._mark_computation(node, node_idx, [-1])
self._mark_idx_equal(input_node, -1, weight, 0) self._mark_indice_equal(input_node, -1, weight, 0)
if bias: if bias:
self._mark_idx_equal(input_node, -1, bias, 0) self._mark_indice_equal(input_node, -1, bias, 0)
def _assign_matmul_index(self, node, node_idx): def _assign_matmul_indice(self, node, node_idx):
""" """
Assign index for matmul op. Assign index for matmul op.
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
@ -302,13 +302,13 @@ class TraceIndice(object):
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_index_as_input(node, node_idx, matmul_left) self._assign_index_as_input(node, node_idx, matmul_left)
self._inherit_index(matmul_right, -1, node, -1) self._inherit_indice(matmul_right, -1, node, -1)
self._mark_computation_from_node(matmul_right, node, [-1, -2]) self._mark_computation_from_node(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1]) self._mark_computation(node, node_idx, [-1])
self._mark_idx_equal(matmul_left, -1, matmul_right, -2) self._mark_indice_equal(matmul_left, -1, matmul_right, -2)
def _assign_layernorm_index(self, node, idx): def _assign_layernorm_indice(self, node, idx):
""" """
Assign index for layernorm op. Assign index for layernorm op.
1. assign index as input node 1. assign index as input node
@ -321,7 +321,7 @@ class TraceIndice(object):
self._assign_index_as_input(node, idx) self._assign_index_as_input(node, idx)
self._mark_computation(node, idx, [-1]) self._mark_computation(node, idx, [-1])
def _assign_elementwise_index(self, node, idx): def _assign_elementwise_indice(self, node, idx):
""" """
Assign index for element-wise op (eg. relu sigmoid add mul). Assign index for element-wise op (eg. relu sigmoid add mul).
1. assign index as input node 1. assign index as input node
@ -343,15 +343,15 @@ class TraceIndice(object):
node_in1_shape = get_node_shape(nodes_in[1]) node_in1_shape = get_node_shape(nodes_in[1])
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1):
if node_in0_shape[i] == node_in1_shape[i]: if node_in0_shape[i] == node_in1_shape[i]:
self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
def _assgin_no_change_index(self, node, idx): def _assgin_no_change_indice(self, node, idx):
self._assign_index_as_input(node, idx) self._assign_index_as_input(node, idx)
for node_in in node.args: for node_in in node.args:
if type(node_in) == type(node): if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node) self._mark_computation_from_node(node_in, node)
def _assign_einsum_index(self, node, idx): def _assign_einsum_indice(self, node, idx):
""" """
Assign index for einsum op. Assign index for einsum op.
@ -378,7 +378,7 @@ class TraceIndice(object):
for left_idx, left_str in enumerate(left): for left_idx, left_str in enumerate(left):
if right_indice in left_str: if right_indice in left_str:
source_idx = left_str.index(right_indice) source_idx = left_str.index(right_indice)
self._inherit_index( self._inherit_indice(
input_nodes[left_idx], source_idx, node, right_idx input_nodes[left_idx], source_idx, node, right_idx
) )
@ -388,7 +388,7 @@ class TraceIndice(object):
# self._mark_computation(node, idx, left_str.index(i)) # self._mark_computation(node, idx, left_str.index(i))
# break # break
def _assign_softmax_index(self, node, idx): def _assign_softmax_indice(self, node, idx):
""" """
Assign index for softmax op. Assign index for softmax op.
1. assign index as input node 1. assign index as input node
@ -401,7 +401,7 @@ class TraceIndice(object):
self._assign_index_as_input(node, idx) self._assign_index_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]]) self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_index(self, node, node_idx): def _assign_unsqueeze_indice(self, node, node_idx):
""" """
Assign index for unsqueeze op. Assign index for unsqueeze op.
1. assign new index for unsqueeze dim 1. assign new index for unsqueeze dim
@ -414,7 +414,7 @@ class TraceIndice(object):
self._assign_index_as_input(node, node_idx) self._assign_index_as_input(node, node_idx)
self._add_dim(node_idx, node.args[1]) self._add_dim(node_idx, node.args[1])
def _assign_dropout_index(self, node, node_idx): def _assign_dropout_indice(self, node, node_idx):
""" """
Assign index for unsqueeze op. Assign index for unsqueeze op.
1. assign new index for unsqueeze dim 1. assign new index for unsqueeze dim
@ -425,7 +425,7 @@ class TraceIndice(object):
""" """
self._assign_index_as_input(node, node_idx) self._assign_index_as_input(node, node_idx)
def _assign_ones_like_index(self, node, node_idx): def _assign_ones_like_indice(self, node, node_idx):
""" """
Assign index for oneslike op. Assign index for oneslike op.
1. assign new index for all dim 1. assign new index for all dim
@ -434,9 +434,9 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
self._assign_all_index(node, node_idx) self._assign_all_indice(node, node_idx)
def _assign_view_reshape_index(self, node, node_idx): def _assign_view_reshape_indice(self, node, node_idx):
""" """
Assign index for view and reshape op. Assign index for view and reshape op.
1. get origin shape and target shape by meta info. 1. get origin shape and target shape by meta info.
@ -496,7 +496,7 @@ class TraceIndice(object):
) )
# get new index # get new index
origin_trace = self._find_idx_trace_from_node(origin_node) origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_index_as_input(node, node_idx, origin_node) self._assign_index_as_input(node, node_idx, origin_node)
dim_from.reverse() dim_from.reverse()
for i in dim_from: for i in dim_from:
@ -516,18 +516,18 @@ class TraceIndice(object):
view_dict = { view_dict = {
"idx_from": [origin_trace[i] for i in dim_from], "idx_from": [origin_trace[i] for i in dim_from],
"dim_from": dim_from, "dim_from": dim_from,
"idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in dim_to], "idx_to": [self.indice_trace_list[node_idx]["idx"][i] for i in dim_to],
"dim_to": dim_to, "dim_to": dim_to,
} }
self.idx_view_list[node] = view_dict self.indice_view_list[node] = view_dict
def _merge_equal_idx(self): def _merge_equal_idx(self):
idx_equal = copy.deepcopy(self.idx_trace_equal) idx_equal = copy.deepcopy(self.indice_trace_equal)
idx_equal.reverse() idx_equal.reverse()
for idx in idx_equal: for idx in idx_equal:
merge_to = min(idx) merge_to = min(idx)
merge_from = max(idx) merge_from = max(idx)
for trace in self.idx_trace_list: for trace in self.indice_trace_list:
if merge_from in trace["idx"]: if merge_from in trace["idx"]:
trace["idx"] = [ trace["idx"] = [
merge_to if i == merge_from else i for i in trace["idx"] merge_to if i == merge_from else i for i in trace["idx"]
@ -536,35 +536,35 @@ class TraceIndice(object):
def trace_index(self): def trace_index(self):
for idx, node in enumerate(self.node_list): for idx, node in enumerate(self.node_list):
if node.op == "placeholder": if node.op == "placeholder":
self._assign_all_index(node, idx) self._assign_all_indice(node, idx)
elif node.op == "call_method": elif node.op == "call_method":
if "transpose" in node.name: if "transpose" in node.name:
self._assign_transpose_index(node, idx) self._assign_transpose_indice(node, idx)
elif "permute" in node.name: elif "permute" in node.name:
self._assign_permute_index(node, idx) self._assign_permute_indice(node, idx)
elif "view" in node.name or "reshape" in node.name: elif "view" in node.name or "reshape" in node.name:
self._assign_view_reshape_index(node, idx) self._assign_view_reshape_indice(node, idx)
elif "unsqueeze" in node.name: elif "unsqueeze" in node.name:
self._assign_unsqueeze_index(node, idx) self._assign_unsqueeze_indice(node, idx)
elif any(i in node.name for i in ["to", "contiguous"]): elif any(i in node.name for i in ["to", "contiguous"]):
self._assgin_no_change_index(node, idx) self._assgin_no_change_indice(node, idx)
else: else:
raise NotImplementedError(node.name, "method not implemented yet!") raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == "call_function": elif node.op == "call_function":
if "linear" in node.name: if "linear" in node.name:
self._assign_linear_index(node, idx) self._assign_linear_indice(node, idx)
elif "matmul" in node.name: elif "matmul" in node.name:
self._assign_matmul_index(node, idx) self._assign_matmul_indice(node, idx)
elif "softmax" in node.name: elif "softmax" in node.name:
self._assign_softmax_index(node, idx) self._assign_softmax_indice(node, idx)
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
self._assign_elementwise_index(node, idx) self._assign_elementwise_indice(node, idx)
elif "ones_like" in node.name: elif "ones_like" in node.name:
self._assign_ones_like_index(node, idx) self._assign_ones_like_indice(node, idx)
elif "dropout" in node.name: elif "dropout" in node.name:
self._assign_dropout_index(node, idx) self._assign_dropout_indice(node, idx)
elif "einsum" in node.name: elif "einsum" in node.name:
self._assign_einsum_index(node, idx) self._assign_einsum_indice(node, idx)
elif "getattr" in node.name: elif "getattr" in node.name:
continue # get attr like shape continue # get attr like shape
elif "getitem" in node.name: elif "getitem" in node.name:
@ -575,11 +575,11 @@ class TraceIndice(object):
) )
elif node.op == "call_module": elif node.op == "call_module":
if any(n in node.name for n in ["layernorm", "norm"]): if any(n in node.name for n in ["layernorm", "norm"]):
self._assign_layernorm_index(node, idx) self._assign_layernorm_indice(node, idx)
else: else:
raise NotImplementedError(node.name, "module not implemented yet!") raise NotImplementedError(node.name, "module not implemented yet!")
elif node.op == "get_attr": elif node.op == "get_attr":
self._assign_all_index(node, idx) # get param self._assign_all_indice(node, idx) # get param
elif node.op == "output": elif node.op == "output":
continue continue
else: else: