From a6cdbf9161afc526d3a961708c0b202ca18c3e7e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:24:23 +0800 Subject: [PATCH] seperate trace flow --- colossalai/autochunk/autochunk_codegen.py | 2 +- colossalai/autochunk/search_chunk.py | 53 +-- colossalai/autochunk/select_chunk.py | 3 +- colossalai/autochunk/trace_flow.py | 414 ++++++++++++++++++++ colossalai/autochunk/trace_index.py | 395 ------------------- tests/test_autochunk/benchmark_autochunk.py | 4 +- 6 files changed, 447 insertions(+), 424 deletions(-) create mode 100644 colossalai/autochunk/trace_flow.py diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 3bb2e83be..39728cb79 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -167,7 +167,7 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: - meta_node = chunk_region_search.index_tracer.node_list[node_idx] + meta_node = chunk_region_search.trace_index.node_list[node_idx] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ "chunk_dim" ] diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 5c58bda0c..030b13bdb 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,8 +1,10 @@ import copy from .select_chunk import SelectChunk -from .trace_index import TraceIndex, ReorderGraph +from .trace_index import TraceIndex +from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory +from .trace_flow import TraceFlow from .utils import ( get_node_shape, is_non_compute_node, @@ -14,12 +16,13 @@ class SearchChunk(object): def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm self.print_mem = print_mem - self.index_tracer = TraceIndex(list(gm.graph.nodes)) - self.index_tracer.trace_index() - self.reorder_graph = ReorderGraph(self.index_tracer) - self.memory_estimator = EstimateMemory() - self.chunk_selector = SelectChunk( - self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory + self.trace_index = TraceIndex(list(gm.graph.nodes)) + self.trace_index.trace_index() + self.trace_flow = TraceFlow(self.trace_index) + self.reorder_graph = ReorderGraph(self.trace_index) + self.estimate_memory = EstimateMemory() + self.select_chunk = SelectChunk( + self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -29,7 +32,7 @@ class SearchChunk(object): def _get_free_var(self): free_var_idx = [] - for idx, n in enumerate(self.index_tracer.node_list): + for idx, n in enumerate(self.trace_index.node_list): if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx @@ -99,7 +102,7 @@ class SearchChunk(object): def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] - end_node = self.index_tracer.node_list[end_idx] + end_node = self.trace_index.node_list[end_idx] chunk_infos = [] for end_dim, _ in enumerate(end_trace["idx"]): if len(start_traces) > 1: @@ -113,46 +116,46 @@ class SearchChunk(object): ): continue # check index source align - if not self.index_tracer.check_index_source( + if not self.trace_flow.check_index_source( start_dim, start_node, start_idx, end_dim, end_node ): continue # check index copmute - if not self.index_tracer.check_index_compute( + if not self.trace_flow.check_index_compute( start_idx, end_dim, end_node, end_idx ): continue # flow search - chunk_info = self.index_tracer.flow_search( + chunk_info = self.trace_flow.flow_search( start_idx, start_dim, end_idx, end_dim ) if chunk_info is None: continue # check index copmute - if not self.index_tracer.check_index_duplicate(chunk_info): + if not self.trace_flow.check_index_duplicate(chunk_info): continue chunk_infos.append(chunk_info) return chunk_infos def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] - output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) + output_trace = copy.deepcopy(self.trace_index.idx_trace_list) input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.index_tracer.node_list): + for _, n in enumerate(self.trace_index.node_list): cur_trace = {} for arg in n.args: if type(arg) == type(n) and not is_non_compute_node_except_placeholder( arg ): - cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) + cur_trace[arg] = self.trace_index._find_trace_from_node(arg) input_trace.append(cur_trace) for start_idx in range(max_chunk_region[0], peak_node + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node( - self.index_tracer.node_list[start_idx] - ) or is_non_compute_node(self.index_tracer.node_list[end_idx]): + self.trace_index.node_list[start_idx] + ) or is_non_compute_node(self.trace_index.node_list[end_idx]): continue # select free dim @@ -173,7 +176,7 @@ class SearchChunk(object): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self.chunk_selector._select_best_chunk_region( + best_chunk_region = self.select_chunk._select_best_chunk_region( possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak ) best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) @@ -191,8 +194,8 @@ class SearchChunk(object): init_mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list + ) = self.estimate_memory.estimate_chunk_inference_mem( + self.trace_index.node_list ) mem_peak = init_mem_peak @@ -206,14 +209,14 @@ class SearchChunk(object): mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos + ) = self.estimate_memory.estimate_chunk_inference_mem( + self.trace_index.node_list, chunk_infos ) if self._stop_search(init_mem_peak, mem_peak): break if self.print_mem: self.print_mem = False - self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True + self.estimate_memory.estimate_chunk_inference_mem( + self.trace_index.node_list, chunk_infos, print_mem=True ) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index f0262f1e5..30f4226f5 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,4 +1,5 @@ -from .trace_index import TraceIndex, ReorderGraph +from .trace_index import TraceIndex +from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory from .utils import is_non_compute_node diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py new file mode 100644 index 000000000..f372fa913 --- /dev/null +++ b/colossalai/autochunk/trace_flow.py @@ -0,0 +1,414 @@ +from .trace_index import TraceIndex +from .utils import ( + find_chunk_all_input_nodes, + find_chunk_compute_input_and_output_nodes, + find_idx_by_name, + get_node_shape, + is_non_compute_node, + is_non_compute_node_except_placeholder, +) + + +class TraceFlow(object): + def __init__(self, trace_index: TraceIndex) -> None: + self.trace_index = trace_index + + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): + """ + Check 2 given index: one index should be source of the other + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list) + end_node_trace = self.trace_index._find_trace_from_node(end_node) + end_node_trace_source = end_node_trace["source"][end_dim] + sorted_source = sorted( + end_node_trace_source.items(), key=lambda d: d[0], reverse=True + ) + for node_idx, node_dim in sorted_source: + if node_idx == start_node_idx and start_dim in node_dim: + return True + # it means we meet a node outside the loop, and the node is not input node + if node_idx < start_idx: + return False + return False + + def check_index_compute(self, start_idx, end_dim, end_node, end_idx): + """ + Check 2 given index: check they haven't been computed in the source trace. + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + end_node_trace = self.trace_index._find_trace_from_node(end_node) + end_node_compute = end_node_trace["compute"][end_dim] + if any(start_idx <= i <= end_idx for i in end_node_compute): + return False + return True + + def get_node_chunk_dim(self, node_from, node_from_dim, node_to): + node_from_source = self.trace_index._find_source_trace_from_node(node_from) + dim_source = node_from_source[node_from_dim] + node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list) + for k, v in dim_source.items(): + if k == node_to_idx: + return v + return None + + def _find_inherit_dim(self, input_node, input_dim, node): + input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list) + node_trace_source = self.trace_index._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + if ( + input_node_idx in node_trace_source[node_dim] + and input_dim[0] in node_trace_source[node_dim][input_node_idx] + ): + return node_dim + return None + + def check_index_duplicate(self, chunk_infos, return_dim=False): + input_dim_after_node = {} + for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): + for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): + inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k]) + if inherit_dim: + input_dim_after_node[k] = inherit_dim + + for node in self.trace_index.node_list[ + chunk_infos["region"][0] : chunk_infos["region"][1] + 1 + ]: + if is_non_compute_node_except_placeholder(node): + continue + count = 0 + duplicate_dims = [] + node_trace_source = self.trace_index._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + duplicate_dim = [] + duplicate_flag = False + dim_source = node_trace_source[node_dim] + for k, v in dim_source.items(): + if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: + if k in input_dim_after_node and input_dim_after_node[k] in v: + duplicate_flag = True + duplicate_dim.append((k, v)) + duplicate_dims.append(duplicate_dim) + if duplicate_flag: + count += 1 + + if count > 1: + if return_dim: + return False, duplicate_dims + else: + return False + if return_dim: + return True, None + else: + return True + + def _assgin_single_node_flow( + self, + arg_node, + start_idx, + end_idx, + cur_node_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ): + arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + else: + arg_dim = None + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: + return False + all_node_info[arg_node]["fix_dim"] = list( + set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) + ) + # else add it to list + else: + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def _get_all_node_info(self, end_dim, start_idx, end_idx): + cur_node_list = [ + self.trace_index.node_list[end_idx] + ] # start from the last node + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + + while len(cur_node_list) > 0: + next_node_list = [] + + for cur_node in cur_node_list: + # get cur node info + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] + if cur_node_chunk_dim: + cur_node_compute = self.trace_index._find_compute_trace_from_node( + cur_node + ) + cur_node_source = self.trace_index._find_source_trace_from_node( + cur_node + ) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.args: + if type(arg) != type(cur_node): + continue + if is_non_compute_node(arg): + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) + if flow_flag == False: + return None + + if len(arg_list) == 2: + if any(i in cur_node.name for i in ["add", "mul"]): + for arg in arg_list: + if not ( + start_idx + <= find_idx_by_name(arg.name, self.trace_index.node_list) + < end_idx + ): + continue + arg_chunk_dim = all_node_info[arg]["chunk_dim"] + arg_fix_dim = all_node_info[arg]["fix_dim"] + arg_shape = get_node_shape(arg) + # add all dim as fix dim except chunk dim + for i, shape in enumerate(arg_shape): + if shape != 1 and i != cur_node_chunk_dim: + if i == arg_chunk_dim: + return None + if i not in arg_fix_dim: + arg_fix_dim.append(i) + elif "einsum" in cur_node.name: + pass + elif "matmul" in cur_node.name: + pass + else: + raise NotImplementedError() + cur_node_list = next_node_list + return all_node_info + + def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + input_node_idx = find_idx_by_name( + input_node.name, self.trace_index.node_list + ) + for user in input_node.users.keys(): + if is_non_compute_node(user): + continue + user_idx = find_idx_by_name(user.name, self.trace_index.node_list) + if start_idx <= user_idx <= end_idx: + chunk_dim = all_node_info[user]["chunk_dim"] + if chunk_dim is not None: + user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim] + if input_node_idx in user_source: + input_dict[user_idx] = user_source[input_node_idx] + else: + return None, None + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + return inputs, inputs_dim + + def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info["chunk_dim"] is None: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort( + key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list), + reverse=True, + ) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + if prepose_flag == False: + break + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + if prepose_flag == False: + break + for cur_prepose_node_arg in cur_prepose_node.args: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not ( + start_idx + <= find_idx_by_name( + cur_prepose_node_arg.name, self.trace_index.node_list + ) + < end_idx + ): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + break + # non compute op + else: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + tmp_cur_prepose_nodes = tmp_next_prepose_nodes + + if prepose_flag == False: + maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) + continue + else: + for n in tmp_cur_related_prepose_nodes: + if n not in prepose_nodes: + prepose_nodes.append(n) + if n in maybe_prepose_nodes: + maybe_prepose_nodes.remove(n) + # sort by index + prepose_nodes.sort( + key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list) + ) + + return prepose_nodes + + def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in chunk_info["args"]["prepose_nodes"]: + chunk_node_list.remove(n) + non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + return chunk_info + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.trace_index.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim( + inputs, start_idx, end_idx, all_node_info + ) + if inputs is None: + return None + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, + "args": {}, + } + + # move useless nodes ahead of loop + chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( + all_node_info, start_idx, end_idx + ) + + # find non chunk inputs + chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) + + # reassgin reshape size, some size may have changed due to chunk + chunk_info = self._reassgin_reshape_size(chunk_info) + + return chunk_info + + def _reassgin_reshape_size(self, chunk_info): + chunk_region = chunk_info["region"] + reshape_size = {} + chunk_shape = get_node_shape(chunk_info["outputs"][0])[ + chunk_info["outputs_dim"] + ] + for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]: + if any(i in node.name for i in ["reshape", "view"]): + reshape_args = node.args[1:] + reshape_log = self.trace_index.idx_view_list[node] + chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] + reshape_size[node.name] = {} + for reshape_arg_dim, reshape_arg in enumerate(reshape_args): + if reshape_arg_dim in reshape_log["dim_to"]: + continue + if reshape_arg_dim == chunk_dim: + reshape_size[node.name][reshape_arg.name] = ( + "min(chunk_size, %d - chunk_idx)" % chunk_shape + ) + chunk_info["reshape_size"] = reshape_size + return chunk_info diff --git a/colossalai/autochunk/trace_index.py b/colossalai/autochunk/trace_index.py index 3ac0d7f84..1e8969d87 100644 --- a/colossalai/autochunk/trace_index.py +++ b/colossalai/autochunk/trace_index.py @@ -1,12 +1,8 @@ import copy from .utils import ( - find_chunk_all_input_nodes, - find_chunk_compute_input_and_output_nodes, find_idx_by_name, get_node_shape, - is_non_compute_node, - is_non_compute_node_except_placeholder, ) @@ -588,394 +584,3 @@ class TraceIndex(object): continue else: raise NotImplementedError(node.op, "op not implemented yet!") - # self._merge_equal_idx() - - def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): - """ - Check 2 given index: one index should be source of the other - Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - start_node_idx = find_idx_by_name(start_node.name, self.node_list) - end_node_trace = self._find_trace_from_node(end_node) - end_node_trace_source = end_node_trace["source"][end_dim] - sorted_source = sorted( - end_node_trace_source.items(), key=lambda d: d[0], reverse=True - ) - for node_idx, node_dim in sorted_source: - if node_idx == start_node_idx and start_dim in node_dim: - return True - # it means we meet a node outside the loop, and the node is not input node - if node_idx < start_idx: - return False - return False - - def check_index_compute(self, start_idx, end_dim, end_node, end_idx): - """ - Check 2 given index: check they haven't been computed in the source trace. - Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - end_node_trace = self._find_trace_from_node(end_node) - end_node_compute = end_node_trace["compute"][end_dim] - if any(start_idx <= i <= end_idx for i in end_node_compute): - return False - return True - - def get_node_chunk_dim(self, node_from, node_from_dim, node_to): - node_from_source = self._find_source_trace_from_node(node_from) - dim_source = node_from_source[node_from_dim] - node_to_idx = find_idx_by_name(node_to.name, self.node_list) - for k, v in dim_source.items(): - if k == node_to_idx: - return v - return None - - def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = find_idx_by_name(input_node.name, self.node_list) - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(get_node_shape(node))): - if ( - input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx] - ): - return node_dim - return None - - def check_index_duplicate(self, chunk_infos, return_dim=False): - input_dim_after_node = {} - for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): - for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) - if inherit_dim: - input_dim_after_node[k] = inherit_dim - - for node in self.node_list[ - chunk_infos["region"][0] : chunk_infos["region"][1] + 1 - ]: - if is_non_compute_node_except_placeholder(node): - continue - count = 0 - duplicate_dims = [] - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(get_node_shape(node))): - duplicate_dim = [] - duplicate_flag = False - dim_source = node_trace_source[node_dim] - for k, v in dim_source.items(): - if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in input_dim_after_node and input_dim_after_node[k] in v: - duplicate_flag = True - duplicate_dim.append((k, v)) - duplicate_dims.append(duplicate_dim) - if duplicate_flag: - count += 1 - - if count > 1: - if return_dim: - return False, duplicate_dims - else: - return False - if return_dim: - return True, None - else: - return True - - def _assgin_single_node_flow( - self, - arg_node, - start_idx, - end_idx, - cur_node_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ): - arg_idx = find_idx_by_name(arg_node.name, self.node_list) - # arg in chunk range or be inputs - if not (start_idx <= arg_idx < end_idx): - return True - - # find arg dim - if cur_node_dim is not None: - # dim is computed - if arg_idx in cur_node_compute[cur_node_dim]: - return False - if arg_idx not in cur_node_source[cur_node_dim]: - arg_dim = None - else: - arg_dim = cur_node_source[cur_node_dim][arg_idx][0] - else: - arg_dim = None - - # get fix dim - arg_fix_dim = [] - if cur_node_dim is not None: - for i in cur_node_fix_dim: - fix_dim_source = cur_node_source[i] - if arg_idx in fix_dim_source: - arg_fix_dim.append(fix_dim_source[arg_idx][0]) - - # if already in node_info, arg dim must be same - if arg_node in all_node_info: - if all_node_info[arg_node]["chunk_dim"] != arg_dim: - return False - all_node_info[arg_node]["fix_dim"] = list( - set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) - ) - # else add it to list - else: - all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} - - next_node_list.append(arg_node) - return True - - def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [self.node_list[end_idx]] # start from the last node - all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} - - while len(cur_node_list) > 0: - next_node_list = [] - - for cur_node in cur_node_list: - # get cur node info - cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] - cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - if cur_node_chunk_dim: - cur_node_compute = self._find_compute_trace_from_node(cur_node) - cur_node_source = self._find_source_trace_from_node(cur_node) - else: - cur_node_compute = cur_node_source = None - - # get all valid args - arg_list = [] - for arg in cur_node.args: - if type(arg) != type(cur_node): - continue - if is_non_compute_node(arg): - continue - arg_list.append(arg) - flow_flag = self._assgin_single_node_flow( - arg, - start_idx, - end_idx, - cur_node_chunk_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ) - if flow_flag == False: - return None - - if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul"]): - for arg in arg_list: - if not ( - start_idx - <= find_idx_by_name(arg.name, self.node_list) - < end_idx - ): - continue - arg_chunk_dim = all_node_info[arg]["chunk_dim"] - arg_fix_dim = all_node_info[arg]["fix_dim"] - arg_shape = get_node_shape(arg) - # add all dim as fix dim except chunk dim - for i, shape in enumerate(arg_shape): - if shape != 1 and i != cur_node_chunk_dim: - if i == arg_chunk_dim: - return None - if i not in arg_fix_dim: - arg_fix_dim.append(i) - elif "einsum" in cur_node.name: - pass - elif "matmul" in cur_node.name: - pass - else: - raise NotImplementedError() - cur_node_list = next_node_list - return all_node_info - - def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): - inputs_dim = [] - remove_inputs = [] - for input_node in inputs: - input_dict = {} - input_node_idx = find_idx_by_name(input_node.name, self.node_list) - for user in input_node.users.keys(): - if is_non_compute_node(user): - continue - user_idx = find_idx_by_name(user.name, self.node_list) - if start_idx <= user_idx <= end_idx: - chunk_dim = all_node_info[user]["chunk_dim"] - if chunk_dim is not None: - user_source = self._find_source_trace_from_node(user)[chunk_dim] - if input_node_idx in user_source: - input_dict[user_idx] = user_source[input_node_idx] - else: - return None, None - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - for i in remove_inputs: - if i in inputs: - inputs.remove(i) - return inputs, inputs_dim - - def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): - # get all possible prepose nodes - maybe_prepose_nodes = [] - for node, node_info in all_node_info.items(): - if node_info["chunk_dim"] is None: - maybe_prepose_nodes.append(node) - maybe_prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.node_list), - reverse=True, - ) # from last node to first node - prepose_nodes = [] - # set every node as root, search its args, if all legal, turn root and args as prepose nodes - while len(maybe_prepose_nodes) > 0: - tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] - tmp_cur_related_prepose_nodes = [] - prepose_flag = True - - # loop cur node's all arg until out of chunk - while len(tmp_cur_prepose_nodes) > 0: - if prepose_flag == False: - break - tmp_next_prepose_nodes = [] - tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) - for cur_prepose_node in tmp_cur_prepose_nodes: - if prepose_flag == False: - break - for cur_prepose_node_arg in cur_prepose_node.args: - if type(cur_prepose_node_arg) != type(cur_prepose_node): - continue - # out of loop - if not ( - start_idx - <= find_idx_by_name( - cur_prepose_node_arg.name, self.node_list - ) - < end_idx - ): - continue - # compute op in loop - elif cur_prepose_node_arg in all_node_info: - if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - else: - prepose_flag = False - break - # non compute op - else: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - tmp_cur_prepose_nodes = tmp_next_prepose_nodes - - if prepose_flag == False: - maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) - continue - else: - for n in tmp_cur_related_prepose_nodes: - if n not in prepose_nodes: - prepose_nodes.append(n) - if n in maybe_prepose_nodes: - maybe_prepose_nodes.remove(n) - # sort by index - prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) - - return prepose_nodes - - def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): - # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.node_list[start_idx : end_idx + 1] - # also need to get some prepose node's arg out of non_chunk_inputs - for n in chunk_info["args"]["prepose_nodes"]: - chunk_node_list.remove(n) - non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - return chunk_info - - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - # get every node's chunk dim and fix dim - all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) - if all_node_info is None: - return None - - # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim( - inputs, start_idx, end_idx, all_node_info - ) - if inputs is None: - return None - - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "node_chunk_dim": all_node_info, - "args": {}, - } - - # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( - all_node_info, start_idx, end_idx - ) - - # find non chunk inputs - chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) - - # reassgin reshape size, some size may have changed due to chunk - chunk_info = self._reassgin_reshape_size(chunk_info) - - return chunk_info - - def _reassgin_reshape_size(self, chunk_info): - chunk_region = chunk_info["region"] - reshape_size = {} - chunk_shape = get_node_shape(chunk_info["outputs"][0])[ - chunk_info["outputs_dim"] - ] - for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: - if any(i in node.name for i in ["reshape", "view"]): - reshape_args = node.args[1:] - reshape_log = self.idx_view_list[node] - chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] - reshape_size[node.name] = {} - for reshape_arg_dim, reshape_arg in enumerate(reshape_args): - if reshape_arg_dim in reshape_log["dim_to"]: - continue - if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = ( - "min(chunk_size, %d - chunk_idx)" % chunk_shape - ) - chunk_info["reshape_size"] = reshape_size - return chunk_info diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 081f01368..7a9d8cdee 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -104,8 +104,8 @@ def benchmark_evoformer(): model = evoformer_base().cuda() # build autochunk model - # max_memory = 1000 # MB fit memory mode - max_memory = None # min memory mode + max_memory = 1000 # MB fit memory mode + # max_memory = None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) # build openfold