import copy from typing import Dict, List from torch.fx.node import Node from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape class TraceIndice(object): """ Trace all indice information for every node. Indice is a logical concept. Equal dims can been treated as one indice. eg. dim(x1) = [a, b, c] dim(x2) = [d, e, f] and we have x3 = x1 * x2. then a=d, b=e, c=f, due to the broadcast property, dim(x1)=dim(x2)=dim(x3)=[a, b, c] This class will record every node's dims' indice, compute and source. Attributes: node_list (List) indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}] indice_view_list (Dict): not used for now indice_count (int): record indice number Args: node_list (List) """ def __init__(self, node_mgr: NodeMgr) -> None: self.node_mgr = node_mgr self.indice_trace_list = self._init_indice_trace_list() self.indice_view_list = {} self.indice_count = -1 self.active_node_list = [] def _init_indice_trace_list(self) -> List: indice_trace_list = [] for n in self.node_mgr.get_node_list(): if get_node_shape(n) != None: cur_trace = { "indice": [None for _ in range(len(get_node_shape(n)))], "compute": [[] for _ in range(len(get_node_shape(n)))], "source": [{} for _ in range(len(get_node_shape(n)))], } else: cur_trace = {"indice": [], "compute": [], "source": []} indice_trace_list.append(cur_trace) return indice_trace_list def set_active_nodes(self, active_node_list: List) -> None: self.active_node_list = active_node_list def _add_indice(self) -> int: """ Update the count and return it. To record the idx number. Returns: indice_count: int """ self.indice_count += 1 return self.indice_count def _del_dim(self, idx: int, dim_idx: int) -> None: """ delete a dim for indice, compute and source """ self.indice_trace_list[idx]["indice"].pop(dim_idx) self.indice_trace_list[idx]["compute"].pop(dim_idx) self.indice_trace_list[idx]["source"].pop(dim_idx) def _add_dim(self, node_idx: int, dim_idx: int) -> None: """ add a dim for indice, compute and source """ # need to remap if dim_idx < 0, e.g. -1 if dim_idx < 0: dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx] self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice()) self.indice_trace_list[node_idx]["compute"].insert(dim_idx, []) self.indice_trace_list[node_idx]["source"].insert(dim_idx, {}) def _add_source( self, node_from: Node, node_from_dim: int, node_to: Node, node_to_dim: int, init=False, ) -> None: node_from_dim = self._transform_indice(node_from, node_from_dim) node_from_trace_source = self._find_source_trace_from_node(node_from) node_to_dim = self._transform_indice(node_to, node_to_dim) node_to_trace_source = self._find_source_trace_from_node(node_to) node_from_idx = self.node_mgr.find_node_idx(node_from) if init: node_to_trace_source[node_to_dim] = {} # add dim to cur new source if node_from_idx not in node_to_trace_source[node_to_dim]: node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] else: if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) # update inputs source for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): if node_idx not in node_to_trace_source[node_to_dim]: node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) else: for d in node_dim: if d not in node_to_trace_source[node_to_dim][node_idx]: node_to_trace_source[node_to_dim][node_idx].append(d) def _transform_indice(self, node: Node, node_dim: int) -> int: node_idx = self._find_indice_trace_from_node(node) dims = list(range(len(node_idx))) return dims[node_dim] def _inherit_indice( self, node_from: Node, node_from_dim: int, node_to: Node, node_to_dim: int, init: bool = True, ) -> None: """ node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source """ node_from_dim = self._transform_indice(node_from, node_from_dim) node_to_dim = self._transform_indice(node_to, node_to_dim) node_from_trace = self._find_trace_from_node(node_from) node_to_trace = self._find_trace_from_node(node_to) if init: node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim]) else: for j in node_from_trace["compute"][node_from_dim]: if j not in node_to_trace["compute"][node_to_dim]: node_to_trace["compute"][node_to_dim].append(j) self._add_source(node_from, node_from_dim, node_to, node_to_dim, init) def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None: """ inherit all dims with init """ # find indice just for assert length node_from_indice = self._find_indice_trace_from_node(node_from) node_to_indice = self._find_indice_trace_from_node(node_to) assert len(node_from_indice) == len(node_to_indice) for i in range(len(node_from_indice)): self._inherit_indice(node_from, i, node_to, i, init=True) def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None: """ inherit indice from node without init """ if exclude == None: exclude = [] else: exclude = [self._transform_indice(node_to, i) for i in exclude] node_from_compute = self._find_compute_trace_from_node(node_from) node_to_compute = self._find_compute_trace_from_node(node_to) # assert len(node_from_compute) == len(node_to_compute) for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): if self._transform_indice(node_to, i) in exclude: continue self._inherit_indice(node_from, i, node_to, i, init=False) def _mark_computation(self, node: Node, idx: int, dim: int) -> None: """ Mark some dims of node as computed. Args: node (node) idx (int): node index dim (list or int): dims to be marked as computed """ if isinstance(dim, int): dim = [dim] dims = list(range(len(get_node_shape(node)))) for d in dim: cur_dim = dims[d] if idx not in self.indice_trace_list[idx]["compute"][cur_dim]: self.indice_trace_list[idx]["compute"][cur_dim].append(idx) def _find_trace_from_node(self, node: Node) -> Dict: """ Find node idx and compute trace by the node. Args: node (node) Returns: idx (list): idx of the node compute (list): computed idx of the node. """ node_idx = self.node_mgr.find_node_idx(node) node_dict = self.indice_trace_list[node_idx] return node_dict def _find_source_trace_from_node(self, node: Node) -> List: """ Find node source trace by the node. Args: node (node) Returns: idx (list): idx of the node compute (list): computed idx of the node. """ node_idx = self.node_mgr.find_node_idx(node) node_dict = self.indice_trace_list[node_idx] return node_dict["source"] def _find_indice_trace_from_node(self, node) -> List: """ Find node idx trace by the node. Args: node (node) Returns: idx (list): idx of the node """ node_idx = self.node_mgr.find_node_idx(node) return self.indice_trace_list[node_idx]["indice"] def _find_compute_trace_from_node(self, node: Node) -> List: """ Find node compute trace by the node. Args: node (node) Returns: compute (list): computed idx of the node. """ node_idx = self.node_mgr.find_node_idx(node) return self.indice_trace_list[node_idx]["compute"] def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None: """ Assign node's trace as its input node. Args: node (node) node_idx (int) """ if input_node == None: input_node = find_first_tensor_arg(node) self._inherit_all_indice(input_node, node) def _assign_all_indice(self, node: Node, node_idx: int) -> None: """ Add new indice for all node's dims. Args: node (node) node_idx (int) """ shape = node.meta["tensor_meta"].shape if shape is None: return new_trace = [] for _ in shape: new_trace.append(self._add_indice()) self.indice_trace_list[node_idx]["indice"] = new_trace def _assign_transpose_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for transpose op. 1. swap input's dim according to transpose args 2. inherit input's computation Args: node (node) node_idx (int) """ input_node = node.args[0] tranpose_dim = node.args[1:] self._assign_indice_as_input(node, node_idx, input_node) self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) def _assign_permute_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for permute op. 1. swap input's dim according to permute args 2. inherit input's computation Args: node (node) node_idx (int) """ permute_dim = flat_list(node.args[1:]) input_node = node.args[0] self._assign_indice_as_input(node, node_idx, input_node) for idx, d in enumerate(permute_dim): self._inherit_indice(input_node, d, node, idx) def _assign_linear_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for linear op. 1. copy trace from input node and change last indice according to weight 2. mark equal for input node last indice, weight first dim and bias dim. 3. inherit input's computation, mark computation for last dim. Args: node (node) node_idx (int) """ self._assign_indice_as_input(node, node_idx) if len(node.args) >= 2: weight = node.args[1] self._inherit_indice(weight, 1, node, -1) else: self._del_dim(node_idx, -1) self._add_dim(node_idx, -1) self._mark_computation(node, node_idx, [-1]) def _assign_addmm_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for addmm op. Args: node (node) node_idx (int) """ bias, input_node, weight = node.args assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2 self._assign_indice_as_input(node, node_idx, input_node) self._inherit_indice(weight, 1, node, -1) self._inherit_more_indice_from_node_with_exclude(bias, node) self._mark_computation(node, node_idx, [-1]) def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for baddbmm(batch add and batch matmul) op. add, matmul_left, matmul_right = args out = add + (matmul_left x matmul_right) Args: node (node) node_idx (int) """ add, matmul_left, matmul_right = node.args assert get_node_shape(add) == get_node_shape(node) assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) self._assign_indice_as_input(node, node_idx, matmul_left) # matmul self._inherit_indice(matmul_right, -1, node, -1) self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1]) self._mark_computation(node, node_idx, [-1]) # add self._inherit_more_indice_from_node_with_exclude(add, node) def _assign_matmul_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for matmul op. 1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length) 2. mark equal for input matmul_left -1 indice and matmul_right -2 dim. 3. inherit matmul_left and matmul_right computation, mark computation for last dim. Args: node (node) node_idx (int) """ matmul_left, matmul_right = node.args assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) self._assign_indice_as_input(node, node_idx, matmul_left) self._inherit_indice(matmul_right, -1, node, -1) self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for conv2d op. Args: node (node) node_idx (int) """ # get conv module node_targets = node.target.split(".") conv_module = node.graph.owning_module for i in node_targets: conv_module = getattr(conv_module, i) assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented" # get conv input assert len(node.args) == 1 input_node = node.args[0] assert len(get_node_shape(input_node)) == 4 # assign index self._assign_indice_as_input(node, node_idx, input_node) self._del_dim(node_idx, 1) self._add_dim(node_idx, 1) self._mark_computation(node, node_idx, [1, 2, 3]) def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for interpolate op. Args: node (node) node_idx (int) """ # get conv input assert node.kwargs["size"] is None assert len(get_node_shape(node)) == 4 # assign index self._assign_indice_as_input(node, node_idx) self._mark_computation(node, node_idx, [-1, -2]) def _assign_layernorm_indice(self, node, idx): """ Assign indice for layernorm op. 1. assign indice as input node 2. inherit computation and mark last 2 dims as computed. Args: node (node) node_idx (int) """ self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [-1]) def _assign_groupnorm_indice(self, node, idx): """ Assign indice for groupnorm op. Args: node (node) node_idx (int) """ assert len(get_node_shape(node)) == 4 self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [-1, -2, -3]) def _assign_elementwise_indice(self, node, idx): """ Assign indice for element-wise op (eg. relu sigmoid add mul). 1. assign indice as input node 2. inherit computation from all input nodes. Args: node (node) node_idx (int) """ self._assign_indice_as_input(node, idx) nodes_in = [] for node_in in node.args: if type(node_in) == type(node): nodes_in.append(node_in) self._inherit_more_indice_from_node_with_exclude(node_in, node) def _assign_no_change_indice(self, node, idx): self._assign_indice_as_input(node, idx) for node_in in node.args: if type(node_in) == type(node): self._inherit_more_indice_from_node_with_exclude(node_in, node) def _assign_einsum_indice(self, node, idx): """ Assign indice for einsum op. Args: node (node) node_idx (int) """ patterns = node.args[0] input_nodes = node.args[1:] patterns = patterns.replace(" ", "") left, right = patterns.split("->") left = left.split(",") if "..." in right: replace_list = "!@#$%^&*" target_len = len(get_node_shape(node)) add_len = target_len - len(right) + 3 replace_str = replace_list[:add_len] right = right.replace("...", replace_str) for ll in range(len(left)): left[ll] = left[ll].replace("...", replace_str) all_index = [] for i in left: for c in i: all_index.append(c) all_index = set(all_index) for right_idx, right_indice in enumerate(right): for left_idx, left_str in enumerate(left): if right_indice in left_str: source_idx = left_str.index(right_indice) self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx) def _assign_softmax_indice(self, node, idx): """ Assign indice for softmax op. 1. assign indice as input node 2. inherit computation and mark softmax dim as computed. Args: node (node) node_idx (int) """ self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [node.kwargs["dim"]]) def _assign_split_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for split op. Args: node (node) node_idx (int) """ self._assign_indice_as_input(node, node_idx) dim_idx = node.kwargs["dim"] self._del_dim(node_idx, dim_idx) self._add_dim(node_idx, dim_idx) def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for unsqueeze op. 1. assign new indice for unsqueeze dim Args: node (node) node_idx (int) """ self._del_dim(node_idx, -1) self._assign_indice_as_input(node, node_idx) dim_idx = node.args[1] # unsqueeze(-1) = unsqueeze(shape_num + 1) if dim_idx < 0: dim_idx = list(range(len(get_node_shape(node))))[dim_idx] self._add_dim(node_idx, dim_idx) def _assign_cat_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for cat op. Args: node (node) node_idx (int) """ nodes_in = flat_list(node.args[0]) self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) for n in nodes_in[1:]: self._inherit_more_indice_from_node_with_exclude(n, node) cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) self._add_dim(node_idx, cat_dim) def _assign_sum_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for sum op. Args: node (node) node_idx (int) """ nodes_in = flat_list(node.args[0]) self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) for n in nodes_in[1:]: self._inherit_more_indice_from_node_with_exclude(n, node) cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) def _assign_flatten_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for flatten op. Args: node (node) node_idx (int) """ nodes_in = node.args[0] nodes_in_shape = get_node_shape(nodes_in) flatten_start_dim = node.args[1] flatten_dim_num = len(nodes_in_shape) - flatten_start_dim - 1 assert flatten_dim_num > 0 for _ in range(flatten_dim_num): self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx, nodes_in) for _ in range(flatten_dim_num + 1): self._del_dim(node_idx, -1) self._add_dim(node_idx, -1) def _assign_expand_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for expand op. Args: node (node) node_idx (int) """ expand_shape = node.args[1:] node_in_shape = get_node_shape(node.args[0]) assert len(expand_shape) == len(node_in_shape) self._assign_indice_as_input(node, node_idx) for i in range(len(node_in_shape)): if expand_shape[i] == node_in_shape[i] or expand_shape[i] == -1: continue elif expand_shape[i] > node_in_shape[i]: self._del_dim(node_idx, i) self._add_dim(node_idx, i) else: raise RuntimeError() def _assign_unbind_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for unbind op. Args: node (node) node_idx (int) """ unbind_dim = node.args[1] self._add_dim(node_idx, unbind_dim) self._assign_indice_as_input(node, node_idx) self._del_dim(node_idx, unbind_dim) def _assign_embedding_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for embedding op. Args: node (node) node_idx (int) """ self._del_dim(node_idx, -1) self._assign_indice_as_input(node, node_idx) self._add_dim(node_idx, -1) def _assign_getitem_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for getitem. getitem can act like slice sometimes Args: node (node) node_idx (int) """ node_args = flat_list(node.args[1:]) # deal with split if get_node_name(node.args[0]) == "split": self._assign_indice_as_input(node, node_idx) self._del_dim(node_idx, node.args[0].kwargs["dim"]) self._add_dim(node_idx, node.args[0].kwargs["dim"]) return # skip non tensor if get_node_shape(node) is None: return # find if slice flag = False for node_arg in node_args: node_arg_str = str(node_arg) if any(i == node_arg_str for i in ["None", "Ellipsis"]): flag = True break if "slice" in node_arg_str: flag = True break if flag == False: return # node args should be like [Ellipsis, slice(start, step, end), None] node_shape = get_node_shape(node) origin_idx_count = 0 new_idx_count = 0 new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args]) for _ in range(new_dim_num): self._del_dim(node_idx, 0) delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args]) for _ in range(delete_dim_num): self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx) for _, node_arg in enumerate(node_args): node_arg_str = str(node_arg) # Ellipsis means [..., ] if "Ellipsis" == node_arg_str: shape_gap = len(node_shape) - len(node_args) + 1 origin_idx_count += shape_gap new_idx_count += shape_gap # slice(None, None, None) means all indexes elif "slice" in node_arg_str: if "slice(None, None, None)" != node_arg_str: self._del_dim(node_idx, new_idx_count) self._add_dim(node_idx, new_idx_count) origin_idx_count += 1 new_idx_count += 1 # None means a new dim elif "None" == node_arg_str: self._add_dim(node_idx, new_idx_count) new_idx_count += 1 elif "0" == node_arg_str: self._del_dim(node_idx, new_idx_count) origin_idx_count += 1 else: raise NotImplementedError() def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for view and reshape op. 1. get origin shape and target shape by meta info. 2. compute the real value of -1 in target shape. 3. determine changed dim, and assign indice for generated dim. 4. log changed dim and generated dim for restore 5. inherit computation. 6. look into view list to see whether the view is associated with other, if so assign equal dim according to previous view. Args: node (node) node_idx (int) """ # get data, turn into number origin_node = node.args[0] origin_shape = origin_node.meta["tensor_meta"].shape target_shape = [] unflated_args = flat_list(node.args) for i in range(1, len(unflated_args)): if isinstance(unflated_args[i], int): target_shape.append(unflated_args[i]) else: target_shape.extend(unflated_args[i].meta["fwd_out"]) # compute the value of -1 if -1 in target_shape: origin_product = 1 for i in origin_shape: origin_product *= i target_product = -1 for i in target_shape: target_product *= i shape_idx = target_shape.index(-1) target_shape[shape_idx] = origin_product // target_product # find same dim dim_to_same_dim = [] dim_from_same_dim = [] for i in range(len(origin_shape)): if origin_shape[i] == target_shape[i]: dim_to_same_dim.append(i) dim_from_same_dim.append(i) else: break for i in range(-1, -len(origin_shape), -1): if origin_shape[i] == target_shape[i]: dim_to_same_dim.append(len(target_shape) + i) dim_from_same_dim.append(len(origin_shape) + i) else: break dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim)) dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim)) assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to) dim_diff = len(dim_from) - len(dim_to) if dim_diff > 0: # dim merge for i in range(dim_diff): self._add_dim(node_idx, -1) elif dim_diff < 0: # dim expand for i in range(-dim_diff): self._del_dim(node_idx, -1) # get new indice origin_trace = self._find_indice_trace_from_node(origin_node) self._assign_indice_as_input(node, node_idx, origin_node) dim_from.reverse() for i in dim_from: self._del_dim(node_idx, i) for i in dim_to: self._add_dim(node_idx, i) dim_from.reverse() # inherit indice from current node if len(dim_from) != 0 and len(dim_to) != 0: if dim_diff == 1: if origin_shape[dim_from[0]] == 1: self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False) elif origin_shape[dim_from[1]] == 1: self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) elif dim_diff == -1: if target_shape[dim_to[0]] == 1: self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False) elif target_shape[dim_to[1]] == 1: self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) # log view, not used now view_dict = { "idx_from": [origin_trace[i] for i in dim_from], "dim_from": dim_from, "idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to], "dim_to": dim_to, } self.indice_view_list[node] = view_dict def _clear_trace(self, node_idx: int) -> None: """ clear too far trace to speed up computation """ trace_barrier = max(node_idx - 100, 0) active_nodes = self.active_node_list[trace_barrier] active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()] trace = self.indice_trace_list[node_idx] # clear compute for dim_compute in trace["compute"]: for i in range(len(dim_compute) - 1, -1, -1): if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes: dim_compute.pop(i) continue # clear source for dim_source in trace["source"]: for k in list(dim_source.keys()): if k < trace_barrier and k not in active_nodes: dim_source.pop(k) def trace_indice(self) -> None: for idx, node in enumerate(self.node_mgr.get_node_list()): node_name = get_node_name(node) if node.op == "placeholder": self._assign_all_indice(node, idx) elif node.op == "call_method": if "transpose" == node_name: self._assign_transpose_indice(node, idx) elif "permute" == node_name: self._assign_permute_indice(node, idx) elif "view" == node_name or "reshape" == node_name: self._assign_view_reshape_indice(node, idx) elif "unsqueeze" == node_name: self._assign_unsqueeze_indice(node, idx) elif "split" == node_name: self._assign_split_indice(node, idx) elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]): self._assign_no_change_indice(node, idx) elif "new_ones" == node_name: self._assign_all_indice(node, idx) elif "flatten" == node_name: self._assign_flatten_indice(node, idx) elif "expand" == node_name: self._assign_expand_indice(node, idx) elif "unbind" == node_name: self._assign_unbind_indice(node, idx) elif "softmax" == node_name: self._assign_softmax_indice(node, idx) elif any(i == node_name for i in ["size"]): continue else: raise NotImplementedError(node_name, "method not implemented yet!") elif node.op == "call_function": if "linear" == node_name: self._assign_linear_indice(node, idx) elif "cat" == node_name: self._assign_cat_indice(node, idx) elif any(n == node_name for n in ["matmul", "bmm"]): self._assign_matmul_indice(node, idx) elif "softmax" == node_name: self._assign_softmax_indice(node, idx) elif any( n == node_name for n in [ "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", "sin", "cos", ] ): self._assign_elementwise_indice(node, idx) elif "einsum" == node_name: self._assign_einsum_indice(node, idx) elif "sum" == node_name: self._assign_sum_indice(node, idx) elif "layer_norm" == node_name: self._assign_layernorm_indice(node, idx) elif "getitem" == node_name: self._assign_getitem_indice(node, idx) elif "addmm" == node_name: self._assign_addmm_indice(node, idx) elif "baddbmm" == node_name: self._assign_baddbmm_indice(node, idx) elif "interpolate" == node_name: self._assign_interpolate_indice(node, idx) elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]): self._assign_all_indice(node, idx) elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): continue else: raise NotImplementedError(node_name, "function not implemented yet!") elif node.op == "call_module": node_name = get_module_node_name(node) if "layernorm" == node_name: self._assign_layernorm_indice(node, idx) elif "groupnorm" == node_name: self._assign_groupnorm_indice(node, idx) elif "embedding" == node_name: self._assign_embedding_indice(node, idx) elif "linear" == node_name: self._assign_linear_indice(node, idx) elif "conv2d" == node_name: self._assign_conv2d_indice(node, idx) elif "identity" == node_name: self._assign_no_change_indice(node, idx) elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]): self._assign_elementwise_indice(node, idx) else: raise NotImplementedError(node_name, "module not implemented yet!") elif node.op == "get_attr": self._assign_all_indice(node, idx) # get param elif node.op == "output": continue else: raise NotImplementedError(node.op, "op not implemented yet!") # limit trace range self._clear_trace(idx)