from typing import Dict, List, Tuple

from torch.fx.node import Node

from .trace_indice import TraceIndice
from .utils import (
    NodeMgr,
    find_chunk_all_input_nodes,
    find_chunk_compute_input_and_output_nodes,
    find_tensor_shape_node,
    flat_list,
    get_node_name,
    get_node_shape,
    is_non_compute_node,
)


class TraceFlow(object):
    """
    Search the chunk flow of a candidate chunk region.

    Given a region ``[start_idx, end_idx]`` of node indices and a chunk dim on
    the region's last node, the methods below walk the graph backwards through
    the node inputs and record, for every node reached, which of its dims is
    chunked (``chunk_dim``) and which dims must stay whole (``fix_dim``).
    ``flow_search`` is the entry point that assembles the resulting
    ``chunk_info`` dict.
    """

    def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
        self.trace_indice = trace_indice
        self.node_mgr = node_mgr

    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
        """
        Check 2 given index: one index should be source of the other

        Args:
            start_dim (int): start node chunk dim
            start_node (node): start node
            start_idx (int): start node index (not used by this check, the
                index is looked up again from ``start_node``)
            end_dim (int): end node chunk dim
            end_node (node): end node

        Returns:
            bool: True if check pass
        """
        # we use start_node_idx instead of real chunk index
        start_node_idx = self.node_mgr.find_node_idx(start_node)
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_trace_source = end_node_trace["source"][end_dim]
        # visit the recorded source nodes from the largest node index down
        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
        for node_idx, node_dim in sorted_source:
            if node_idx == start_node_idx and start_dim in node_dim:
                return True
            # it means we meet a node outside the loop, and the node is not input node
            if node_idx < start_node_idx:
                return False
        return False

    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
        """
        Check 2 given index: check they haven't been computed in the source trace.

        Args:
            start_idx (int): chunk region start (node index)
            end_dim (int): end node chunk dim
            end_node (node): end node
            end_idx (int): chunk region end (node index)

        Returns:
            bool: True if check pass, i.e. no entry of the end node's compute
                trace for ``end_dim`` lies in ``[start_idx, end_idx]``
        """
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_compute = end_node_trace["compute"][end_dim]
        return not any(start_idx <= i <= end_idx for i in end_node_compute)

    def _assgin_single_node_flow(
        self,
        arg_node: Node,
        start_idx: int,
        end_idx: int,
        cur_node: Node,
        cur_node_dim: int,
        cur_node_compute: Dict,
        cur_node_source: Dict,
        cur_node_fix_dim: List,
        all_node_info: Dict,
        next_node_list: List,
    ) -> bool:
        """
        Given the current node and one of its arg node,
        this function finds out arg node's chunk dim and fix dim

        Args:
            arg_node (Node): input node
            start_idx (int): chunk region start
            end_idx (int): chunk region end
            cur_node (Node): current node, a user of ``arg_node``
            cur_node_dim (int): current node chunk dim, may be None
            cur_node_compute (Dict): current node compute dict
            cur_node_source (Dict): current node source dict
            cur_node_fix_dim (List): current node fix dim
            all_node_info (Dict): all node chunk info in the chunk region,
                updated in place
            next_node_list (List): nodes to visit next, ``arg_node`` is
                appended when it is seen for the first time

        Returns:
            bool: True if this node can be added to the flow, vice versa.
        """
        arg_idx = self.node_mgr.find_node_idx(arg_node)
        # arg outside [start_idx, end_idx) needs no chunk info, nothing to do
        if not (start_idx <= arg_idx < end_idx):
            return True

        # get fix dim: the arg dims that cur node's fix dims come from, merged
        # with what is already recorded for the arg. This list is only used
        # for the conflict check below.
        arg_fix_dim = []
        if cur_node_dim is not None:
            for i in cur_node_fix_dim:
                fix_dim_source = cur_node_source[i]
                if arg_idx in fix_dim_source:
                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
        if arg_node in all_node_info:
            arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))

        # find arg dim
        if cur_node_dim is not None:
            # dim is computed
            if arg_idx in cur_node_compute[cur_node_dim]:
                return False
            if arg_idx not in cur_node_source[cur_node_dim]:
                arg_dim = None
            else:
                arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
                # chunk dim cannot be in fix dims
                if arg_dim in arg_fix_dim:
                    return False
                arg_dim_size = get_node_shape(arg_node)[arg_dim]
                if arg_dim_size == 1:
                    # chunk dim should be None if shape size is 1
                    arg_dim = None
                else:
                    # chunk shape should equal cur node
                    cur_dim_size = get_node_shape(cur_node)[cur_node_dim]
                    if cur_dim_size != 1 and arg_dim_size != cur_dim_size:
                        return False
        else:
            arg_dim = None

        # add arg rest dim as fix dim (this replaces the merged list above)
        arg_fix_dim = list(range(len(get_node_shape(arg_node))))
        if arg_dim is not None:
            arg_fix_dim.remove(arg_dim)

        # if already in node_info, arg dim must be same
        if arg_node in all_node_info:
            if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                return False
            all_node_info[arg_node]["fix_dim"] = arg_fix_dim
        # else add it to list
        else:
            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
            next_node_list.append(arg_node)
        return True

    def _get_all_node_info(self, end_dim, start_idx, end_idx):
        """
        Collect chunk dim and fix dim of every node reachable from the end node

        Starting from the node at ``end_idx``, walk level by level through the
        input nodes and assign their chunk info with
        ``_assgin_single_node_flow``.

        Args:
            end_dim (int): chunk dim of the end node
            start_idx (int): chunk region start
            end_idx (int): chunk region end

        Returns:
            Dict: node -> {"chunk_dim": ..., "fix_dim": ...}, or None if any
                arg cannot be added to the flow
        """
        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]    # start from the last node
        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

        while cur_node_list:
            next_node_list = []

            for cur_node in cur_node_list:
                # get cur node info
                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
                if cur_node_chunk_dim is not None:
                    cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                    cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
                else:
                    cur_node_compute = cur_node_source = None

                # assign flow for all valid args
                for arg in cur_node.all_input_nodes:
                    if type(arg) is not type(cur_node):
                        continue
                    if is_non_compute_node(arg):
                        continue
                    if get_node_shape(arg) is None:
                        continue
                    flow_flag = self._assgin_single_node_flow(
                        arg,
                        start_idx,
                        end_idx,
                        cur_node,
                        cur_node_chunk_dim,
                        cur_node_compute,
                        cur_node_source,
                        cur_node_fix_dim,
                        all_node_info,
                        next_node_list,
                    )
                    if not flow_flag:
                        return None

            cur_node_list = next_node_list
        return all_node_info

    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
        """
        Get chunk dim for every input node for their every entry, remove unchunked nodes

        Args:
            inputs (List[Node]): input nodes, unchunked ones are removed in place
            all_node_info (Dict): describe all node's chunk dim and fix dim
            start_idx (int): chunk start idx
            end_idx (int): chunk end idx

        Returns:
            inputs (List(Node)): new inputs
            inputs_dim (List): chunk dim for inputs, one dict per kept input
                mapping user node index -> dims
            Both are None if a chunked user inside the region does not trace
            its chunk dim back to the input.
        """
        inputs_dim = []
        remove_inputs = []
        for input_node in inputs:
            input_dict = {}
            input_node_idx = self.node_mgr.find_node_idx(input_node)
            for user in input_node.users:
                # skip non compute
                if is_non_compute_node(user):
                    continue
                # untraced node, mostly non compute
                if user not in all_node_info:
                    continue
                user_idx = self.node_mgr.find_node_idx(user)
                if not (start_idx <= user_idx <= end_idx):
                    continue
                chunk_dim = all_node_info[user]["chunk_dim"]
                if chunk_dim is None:
                    continue
                user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
                if input_node_idx not in user_source:
                    return None, None
                # a size-1 dim cannot be chunked, mark it as None
                if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
                    input_dict[user_idx] = [None]
                else:
                    input_dict[user_idx] = user_source[input_node_idx]
            if not input_dict:
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        # remove unchunked inputs
        for i in remove_inputs:
            if i in inputs:
                inputs.remove(i)
        return inputs, inputs_dim

    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> None:
        """
        get all useless nodes in chunk region and prepose them

        The result is not returned, it is stored in
        ``chunk_info["args"]["prepose_nodes"]`` sorted by node index.

        Args:
            all_node_info (Dict): describe all node's chunk dim and fix dim
            start_idx (int): chunk start idx
            end_idx (int): chunk end idx
            chunk_info (Dict): chunk info, its "outputs" are read and its
                "args" entry is updated in place
        """
        # get all possible prepose nodes
        maybe_prepose_nodes = []
        for node, node_info in all_node_info.items():
            if node_info["chunk_dim"] is None:
                maybe_prepose_nodes.append(node)
        for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
            if node not in all_node_info and node not in chunk_info["outputs"]:
                maybe_prepose_nodes.append(node)
        # from last node to first node
        maybe_prepose_nodes.sort(key=self.node_mgr.find_node_idx, reverse=True)

        prepose_nodes = []
        # set every node as root, search its args, if all legal, turn root and args as prepose nodes
        while maybe_prepose_nodes:
            tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
            tmp_cur_related_prepose_nodes = []
            prepose_flag = True

            # loop cur node's all arg until out of chunk
            while tmp_cur_prepose_nodes:
                if not prepose_flag:
                    break
                tmp_next_prepose_nodes = []
                tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
                for cur_prepose_node in tmp_cur_prepose_nodes:
                    if not prepose_flag:
                        break
                    for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                        if type(cur_prepose_node_arg) is not type(cur_prepose_node):
                            continue
                        # out of loop
                        if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
                            continue
                        # compute op in loop
                        elif cur_prepose_node_arg in all_node_info:
                            if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
                                tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                            else:
                                # root depends on a chunked node, it cannot be preposed
                                prepose_flag = False
                                break
                        # non compute op
                        else:
                            tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                tmp_cur_prepose_nodes = tmp_next_prepose_nodes

            if not prepose_flag:
                # drop the root only, it is always the first element
                maybe_prepose_nodes.pop(0)
                continue
            for n in tmp_cur_related_prepose_nodes:
                if n not in prepose_nodes:
                    prepose_nodes.append(n)
                if n in maybe_prepose_nodes:
                    maybe_prepose_nodes.remove(n)
        # sort by index
        prepose_nodes.sort(key=self.node_mgr.find_node_idx)
        chunk_info["args"]["prepose_nodes"] = prepose_nodes

    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
        """
        Log the inputs of the region that are not chunked

        Args:
            chunk_info (Dict): chunk info, "inputs_non_chunk" is extended in place
            start_idx (int): chunk start idx
            end_idx (int): chunk end idx

        Returns:
            Dict: the same chunk_info
        """
        # we need to log input nodes to avoid deleting them in the loop
        chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
        # also need to get some prepose node's arg out of non_chunk_inputs
        for n in chunk_info["args"]["prepose_nodes"]:
            chunk_node_list.remove(n)
        non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
        for i in non_chunk_inputs:
            if i not in chunk_info["inputs"]:
                chunk_info["inputs_non_chunk"].append(i)
        return chunk_info

    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
        """
        Build the chunk info of region ``[start_idx, end_idx]``

        Args:
            start_idx (int): chunk region start
            start_dim (int): start node chunk dim
            end_idx (int): chunk region end
            end_dim (int): end node chunk dim

        Returns:
            Dict: chunk info with keys "region", "inputs", "inputs_non_chunk",
                "inputs_dim", "outputs", "outputs_non_tensor", "outputs_dim",
                "node_chunk_dim", "args" and "reshape_size", or None if no
                legal flow is found
        """
        inputs, outputs = find_chunk_compute_input_and_output_nodes(
            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))

        # get every node's chunk dim and fix dim
        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
        if all_node_info is None:
            return None

        # "node_chunk_dim" is the same dict object as all_node_info, so the
        # updates made for other outputs below are visible through both names
        chunk_info = {
            "region": (start_idx, end_idx),
            "inputs": [],
            "inputs_non_chunk": [],
            "inputs_dim": [],
            "outputs": [self.node_mgr.get_node_by_idx(end_idx)],
            "outputs_non_tensor": {},
            "outputs_dim": [end_dim],
            "node_chunk_dim": all_node_info,
            "args": {},
        }

        # find chunk info for other outputs
        if len(find_tensor_shape_node(outputs)) > 1:
            chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
            if chunk_info is None:
                return None

        # get input nodes' chunk dim
        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
        if inputs is None:
            return None
        chunk_info["inputs"] = inputs
        chunk_info["inputs_dim"] = inputs_dim

        # move useless nodes ahead of loop
        self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)

        # find non chunk inputs
        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)

        # reassign reshape size, some size may have changed due to chunk
        chunk_info = self._reassgin_reshape_size(chunk_info)

        return chunk_info

    def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
                               chunk_info: Dict):
        """
        Find a legal chunk dim for every output other than the end node

        Args:
            outputs (List[Node]): all outputs of the region
            start_idx (int): chunk region start
            start_dim (int): start node chunk dim
            end_idx (int): chunk region end
            end_dim (int): end node chunk dim (not used here)
            chunk_info (Dict): chunk info, updated in place

        Returns:
            Dict: the updated chunk_info, or None if some tensor output has
                no legal chunk dim
        """
        start_node = self.node_mgr.get_node_by_idx(start_idx)
        # loop all outputs
        for output in outputs:
            output_legal = False
            output_idx = self.node_mgr.find_node_idx(output)
            # skip the origin output
            if output_idx == end_idx:
                continue
            # skip non tensor
            if get_node_shape(output) is None:
                # log shape tensor
                if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
                    chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
                continue
            # loop every dim of outputs, try to find a legal one
            for output_dim in range(len(get_node_shape(output))):
                if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
                    continue
                new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
                if new_all_node_info is None:
                    continue
                # check node info legal
                if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim):
                    output_legal = True
                    break
            # not legal
            if not output_legal:
                return None
        return chunk_info

    def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
        """
        check if there is conflict between new node info and old chunk info. If not, update old chunk info

        Args:
            chunk_info (Dict): old chunk info, updated in place
            new_all_node_info (Dict): node info traced from ``output``
            output (Node): the additional output
            output_dim (int): chunk dim of the additional output

        Returns:
            bool: False if a shared node has a different chunk dim, True otherwise
        """
        # check if conflict
        overlap_flag = False
        for k, v in new_all_node_info.items():
            if k in chunk_info["node_chunk_dim"]:
                overlap_flag = True
                if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
                    return False
        # if no overlap, we just consider them as prepose nodes, instead of new output
        if not overlap_flag:
            return True
        # update chunk info
        for k, v in new_all_node_info.items():
            if k in chunk_info["node_chunk_dim"]:
                chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
                    set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
            else:
                chunk_info["node_chunk_dim"][k] = v
        chunk_info["outputs"].append(output)
        chunk_info["outputs_dim"].append(output_dim)
        return True

    def _reassgin_reshape_size(self, chunk_info):
        """
        Some shape args in reshape may have changed due to chunk
        reassign those changed shape

        For every "reshape" / "view" node in the region that is neither a
        prepose node nor fed by a non-chunk input, store
        ``[origin_shape, new_shape]`` (both strings) under the node's name in
        ``chunk_info["reshape_size"]``. In the new shape, the arg at the
        node's chunk dim is replaced by a ``min(chunk_size, ...)`` expression.

        Args:
            chunk_info (Dict): chunk info, updated in place

        Returns:
            Dict: the same chunk_info
        """
        chunk_region = chunk_info["region"]
        reshape_size = {}
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
        for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
            if get_node_name(node) not in ("reshape", "view"):
                continue
            if node in chunk_info["args"]["prepose_nodes"]:
                continue
            if node.args[0] in chunk_info["inputs_non_chunk"]:
                continue
            reshape_args = flat_list(node.args[1:])
            # skip when the only shape arg is a non tensor node carrying several values
            if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
                    reshape_args[0].meta['fwd_out']) > 1:
                continue
            chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
            shape_parts = []
            for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                if reshape_arg_dim == chunk_dim:
                    shape_parts.append("min(chunk_size, %d - chunk_idx)" % chunk_shape)
                elif isinstance(reshape_arg, int):
                    shape_parts.append("%s" % str(reshape_arg))
                else:
                    shape_parts.append("%s" % reshape_arg.name)
            new_shape = ", ".join(shape_parts)
            # strip the surrounding brackets of the list repr
            origin_shape = str(reshape_args)[1:-1]
            reshape_size[node.name] = [origin_shape, new_shape]
        chunk_info["reshape_size"] = reshape_size
        return chunk_info

    def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
                               end_idx: int) -> bool:
        """
        check if region start and end is legal

        Args:
            start_node (Node): start node
            start_dim (int): start node chunk dim
            start_idx (int): chunk region start
            end_node (Node): end node
            end_dim (int): end node chunk dim
            end_idx (int): chunk region end

        Returns:
            bool: True if all checks pass
        """
        # dim cannot be None
        if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
            return False
        # dim size cannot be 1
        if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
            return False
        # must have users
        if not end_node.users:
            return False
        # check index source align
        if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
            return False
        # check index compute
        if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
            return False
        return True