From 0ea903b94edb59df8e24ed86764197292f6345c5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 17:25:13 +0800 Subject: [PATCH] rename trace_index to trace_indice --- colossalai/autochunk/autochunk_codegen.py | 4 +- colossalai/autochunk/reorder_graph.py | 32 +++++------ colossalai/autochunk/search_chunk.py | 32 +++++------ colossalai/autochunk/select_chunk.py | 22 ++++---- colossalai/autochunk/trace_flow.py | 56 +++++++++---------- .../{trace_index.py => trace_indice.py} | 2 +- 6 files changed, 74 insertions(+), 74 deletions(-) rename colossalai/autochunk/{trace_index.py => trace_indice.py} (99%) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index cc39e391e..6e0cfb9cb 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -94,9 +94,9 @@ def _replace_reshape_size(context, node_name, reshape_size_dict): return context -def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body): +def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_idx, node, body): if "ones_like" in node.name: - meta_node = search_chunk.trace_index.node_list[node_idx] + meta_node = search_chunk.trace_indice.node_list[node_idx] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py index bf4420eac..6baa0d2a7 100644 --- a/colossalai/autochunk/reorder_graph.py +++ b/colossalai/autochunk/reorder_graph.py @@ -1,22 +1,22 @@ -from .trace_index import TraceIndex +from .trace_indice import TraceIndice from .utils import find_idx_by_name class ReorderGraph(object): - def __init__(self, trace_index: TraceIndex) -> None: - self.trace_index = trace_index + def __init__(self, trace_indice: TraceIndice) -> None: + self.trace_indice = trace_indice self.all_reorder_map = { - i: i for i in range(len(self.trace_index.idx_trace_list)) + i: i for i in range(len(self.trace_indice.idx_trace_list)) } def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.trace_index.node_list))} + reorder_map = {i: i for i in range(len(self.trace_indice.node_list))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.trace_index.node_list) + find_idx_by_name(i.name, self.trace_indice.node_list) for i in chunk_prepose_nodes ] # put prepose nodes ahead @@ -24,10 +24,10 @@ class ReorderGraph(object): n_idx = chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes - for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]: + for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]: if n in chunk_prepose_nodes: continue - n_idx = find_idx_by_name(n.name, self.trace_index.node_list) + n_idx = find_idx_by_name(n.name, self.trace_indice.node_list) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos @@ -53,25 +53,25 @@ class ReorderGraph(object): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.trace_index.node_list))] + new_node_list = [None for _ in range(len(self.trace_indice.node_list))] for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.trace_index.node_list[old_idx] - self.trace_index.node_list = new_node_list + new_node_list[new_idx] = self.trace_indice.node_list[old_idx] + self.trace_indice.node_list = new_node_list def _reorder_idx_trace(self, reorder_map): # reorder list - new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))] + new_idx_trace_list = [None for _ in range(len(self.trace_indice.idx_trace_list))] for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx] - self.trace_index.idx_trace_list = new_idx_trace_list + new_idx_trace_list[new_idx] = self.trace_indice.idx_trace_list[old_idx] + self.trace_indice.idx_trace_list = new_idx_trace_list # update compute - for idx_trace in self.trace_index.idx_trace_list: + for idx_trace in self.trace_indice.idx_trace_list: compute = idx_trace["compute"] for dim_compute in compute: for idx, i in enumerate(dim_compute): dim_compute[idx] = reorder_map[i] # update source - for idx_trace in self.trace_index.idx_trace_list: + for idx_trace in self.trace_indice.idx_trace_list: source = idx_trace["source"] for dim_idx, dim_source in enumerate(source): new_dim_source = {} diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index ff4c15878..d90e50927 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -7,7 +7,7 @@ from .estimate_memory import EstimateMemory from .reorder_graph import ReorderGraph from .select_chunk import SelectChunk from .trace_flow import TraceFlow -from .trace_index import TraceIndex +from .trace_indice import TraceIndice from .utils import ( get_node_shape, is_non_compute_node, @@ -47,13 +47,13 @@ class SearchChunk(object): def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm self.print_mem = print_mem - self.trace_index = TraceIndex(list(gm.graph.nodes)) - self.trace_index.trace_index() - self.trace_flow = TraceFlow(self.trace_index) - self.reorder_graph = ReorderGraph(self.trace_index) + self.trace_indice = TraceIndice(list(gm.graph.nodes)) + self.trace_indice.trace_index() + self.trace_flow = TraceFlow(self.trace_indice) + self.reorder_graph = ReorderGraph(self.trace_indice) self.estimate_memory = EstimateMemory() self.select_chunk = SelectChunk( - self.trace_index, + self.trace_indice, self.estimate_memory, self.reorder_graph, max_memory=max_memory, @@ -72,7 +72,7 @@ class SearchChunk(object): free_var_idx (List): all indexs of free vars """ free_var_idx = [] - for idx, n in enumerate(self.trace_index.node_list): + for idx, n in enumerate(self.trace_indice.node_list): if n.op == "placeholder": free_var_idx.append(idx) return free_var_idx @@ -156,7 +156,7 @@ class SearchChunk(object): """ start_traces = input_trace[start_idx] end_trace = output_trace[end_idx] - end_node = self.trace_index.node_list[end_idx] + end_node = self.trace_indice.node_list[end_idx] chunk_infos = [] for end_dim, _ in enumerate(end_trace["idx"]): if len(start_traces) > 1: @@ -205,23 +205,23 @@ class SearchChunk(object): possible_chunk_region (List) """ possible_chunk_region = [] - output_trace = copy.deepcopy(self.trace_index.idx_trace_list) + output_trace = copy.deepcopy(self.trace_indice.idx_trace_list) input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.trace_index.node_list): + for _, n in enumerate(self.trace_indice.node_list): cur_trace = {} for arg in n.args: if type(arg) == type(n) and not is_non_compute_node_except_placeholder( arg ): - cur_trace[arg] = self.trace_index._find_trace_from_node(arg) + cur_trace[arg] = self.trace_indice._find_trace_from_node(arg) input_trace.append(cur_trace) for start_idx in range(max_chunk_region[0], peak_node + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node( - self.trace_index.node_list[start_idx] - ) or is_non_compute_node(self.trace_index.node_list[end_idx]): + self.trace_indice.node_list[start_idx] + ) or is_non_compute_node(self.trace_indice.node_list[end_idx]): continue # select free dim @@ -292,7 +292,7 @@ class SearchChunk(object): _, active_node, ) = self.estimate_memory.estimate_chunk_inference_mem( - self.trace_index.node_list + self.trace_indice.node_list ) mem_peak = init_mem_peak @@ -307,13 +307,13 @@ class SearchChunk(object): _, active_node, ) = self.estimate_memory.estimate_chunk_inference_mem( - self.trace_index.node_list, chunk_infos + self.trace_indice.node_list, chunk_infos ) if self._stop_search(init_mem_peak, mem_peak): break if self.print_mem: self.print_mem = False self.estimate_memory.estimate_chunk_inference_mem( - self.trace_index.node_list, chunk_infos, print_mem=True + self.trace_indice.node_list, chunk_infos, print_mem=True ) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 7127cfd64..f0612e45a 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,19 +1,19 @@ from .estimate_memory import EstimateMemory from .reorder_graph import ReorderGraph -from .trace_index import TraceIndex +from .trace_indice import TraceIndice from .utils import is_non_compute_node class SelectChunk(object): def __init__( self, - trace_index: TraceIndex, + trace_indice: TraceIndice, estimate_memory: EstimateMemory, reorder_graph: ReorderGraph, max_memory=None, ): - self.index_tracer = trace_index - self.memory_estimator = estimate_memory + self.trace_indice = trace_indice + self.estimate_memory = estimate_memory self.reorder_graph = reorder_graph if max_memory is not None: self.stratge = "fit_memory" @@ -68,10 +68,10 @@ class SelectChunk(object): for region in possible_chunk_regions: cur_region = region.copy() cur_node_list, cur_region = self.reorder_graph.tmp_reorder( - self.index_tracer.node_list, cur_region + self.trace_indice.node_list, cur_region ) cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( cur_node_list, cur_chunk_infos )[0] cur_chunk_region_peak = cur_mem_peak[ @@ -113,7 +113,7 @@ class SelectChunk(object): chunk_size *= 2 reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( chunk_region_dict["reorder_node_list"], cur_chunk_infos )[0] cur_chunk_max_mem = max( @@ -139,7 +139,7 @@ class SelectChunk(object): mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( chunk_region_dict["reorder_node_list"], cur_chunk_infos )[0] cur_chunk_max_mem = max( @@ -153,7 +153,7 @@ class SelectChunk(object): def _get_compute_node_num(self, start, end): count = 0 - for i in self.index_tracer.node_list[start : end + 1]: + for i in self.trace_indice.node_list[start : end + 1]: if not is_non_compute_node(i): count += 1 return count @@ -178,10 +178,10 @@ class SelectChunk(object): for region in possible_chunk_regions: cur_region = region.copy() cur_node_list, cur_region = self.reorder_graph.tmp_reorder( - self.index_tracer.node_list, cur_region + self.trace_indice.node_list, cur_region ) cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( cur_node_list, cur_chunk_infos )[0] cur_chunk_region_peak = cur_mem_peak[ diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index 7139e7e04..33fade1a5 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -1,4 +1,4 @@ -from .trace_index import TraceIndex +from .trace_indice import TraceIndice from .utils import ( find_chunk_all_input_nodes, find_chunk_compute_input_and_output_nodes, @@ -10,8 +10,8 @@ from .utils import ( class TraceFlow(object): - def __init__(self, trace_index: TraceIndex) -> None: - self.trace_index = trace_index + def __init__(self, trace_indice: TraceIndice) -> None: + self.trace_indice = trace_indice def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): """ @@ -25,8 +25,8 @@ class TraceFlow(object): Returns: bool: True if check pass """ - start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list) - end_node_trace = self.trace_index._find_trace_from_node(end_node) + start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list) + end_node_trace = self.trace_indice._find_trace_from_node(end_node) end_node_trace_source = end_node_trace["source"][end_dim] sorted_source = sorted( end_node_trace_source.items(), key=lambda d: d[0], reverse=True @@ -51,24 +51,24 @@ class TraceFlow(object): Returns: bool: True if check pass """ - end_node_trace = self.trace_index._find_trace_from_node(end_node) + end_node_trace = self.trace_indice._find_trace_from_node(end_node) end_node_compute = end_node_trace["compute"][end_dim] if any(start_idx <= i <= end_idx for i in end_node_compute): return False return True def get_node_chunk_dim(self, node_from, node_from_dim, node_to): - node_from_source = self.trace_index._find_source_trace_from_node(node_from) + node_from_source = self.trace_indice._find_source_trace_from_node(node_from) dim_source = node_from_source[node_from_dim] - node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list) + node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list) for k, v in dim_source.items(): if k == node_to_idx: return v return None def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list) - node_trace_source = self.trace_index._find_source_trace_from_node(node) + input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) + node_trace_source = self.trace_indice._find_source_trace_from_node(node) for node_dim in range(len(get_node_shape(node))): if ( input_node_idx in node_trace_source[node_dim] @@ -82,19 +82,19 @@ class TraceFlow(object): for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): inherit_dim = self._find_inherit_dim( - input_node, v, self.trace_index.node_list[k] + input_node, v, self.trace_indice.node_list[k] ) if inherit_dim: input_dim_after_node[k] = inherit_dim - for node in self.trace_index.node_list[ + for node in self.trace_indice.node_list[ chunk_infos["region"][0] : chunk_infos["region"][1] + 1 ]: if is_non_compute_node_except_placeholder(node): continue count = 0 duplicate_dims = [] - node_trace_source = self.trace_index._find_source_trace_from_node(node) + node_trace_source = self.trace_indice._find_source_trace_from_node(node) for node_dim in range(len(get_node_shape(node))): duplicate_dim = [] duplicate_flag = False @@ -130,7 +130,7 @@ class TraceFlow(object): all_node_info, next_node_list, ): - arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list) + arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): return True @@ -171,7 +171,7 @@ class TraceFlow(object): def _get_all_node_info(self, end_dim, start_idx, end_idx): cur_node_list = [ - self.trace_index.node_list[end_idx] + self.trace_indice.node_list[end_idx] ] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} @@ -183,10 +183,10 @@ class TraceFlow(object): cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] if cur_node_chunk_dim: - cur_node_compute = self.trace_index._find_compute_trace_from_node( + cur_node_compute = self.trace_indice._find_compute_trace_from_node( cur_node ) - cur_node_source = self.trace_index._find_source_trace_from_node( + cur_node_source = self.trace_indice._find_source_trace_from_node( cur_node ) else: @@ -220,7 +220,7 @@ class TraceFlow(object): if not ( start_idx <= find_idx_by_name( - arg.name, self.trace_index.node_list + arg.name, self.trace_indice.node_list ) < end_idx ): @@ -250,16 +250,16 @@ class TraceFlow(object): for input_node in inputs: input_dict = {} input_node_idx = find_idx_by_name( - input_node.name, self.trace_index.node_list + input_node.name, self.trace_indice.node_list ) for user in input_node.users.keys(): if is_non_compute_node(user): continue - user_idx = find_idx_by_name(user.name, self.trace_index.node_list) + user_idx = find_idx_by_name(user.name, self.trace_indice.node_list) if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: - user_source = self.trace_index._find_source_trace_from_node( + user_source = self.trace_indice._find_source_trace_from_node( user )[chunk_dim] if input_node_idx in user_source: @@ -282,7 +282,7 @@ class TraceFlow(object): if node_info["chunk_dim"] is None: maybe_prepose_nodes.append(node) maybe_prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list), + key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list), reverse=True, ) # from last node to first node prepose_nodes = [] @@ -308,7 +308,7 @@ class TraceFlow(object): if not ( start_idx <= find_idx_by_name( - cur_prepose_node_arg.name, self.trace_index.node_list + cur_prepose_node_arg.name, self.trace_indice.node_list ) < end_idx ): @@ -336,14 +336,14 @@ class TraceFlow(object): maybe_prepose_nodes.remove(n) # sort by index prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list) + key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list) ) return prepose_nodes def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1] + chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs for n in chunk_info["args"]["prepose_nodes"]: chunk_node_list.remove(n) @@ -355,7 +355,7 @@ class TraceFlow(object): def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.trace_index.node_list[start_idx : end_idx + 1] + self.trace_indice.node_list[start_idx : end_idx + 1] ) # only single ouput if len(outputs) > 1: @@ -403,10 +403,10 @@ class TraceFlow(object): chunk_shape = get_node_shape(chunk_info["outputs"][0])[ chunk_info["outputs_dim"] ] - for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]: + for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] - reshape_log = self.trace_index.idx_view_list[node] + reshape_log = self.trace_indice.idx_view_list[node] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] reshape_size[node.name] = {} for reshape_arg_dim, reshape_arg in enumerate(reshape_args): diff --git a/colossalai/autochunk/trace_index.py b/colossalai/autochunk/trace_indice.py similarity index 99% rename from colossalai/autochunk/trace_index.py rename to colossalai/autochunk/trace_indice.py index 1e8969d87..9a04c2a0d 100644 --- a/colossalai/autochunk/trace_index.py +++ b/colossalai/autochunk/trace_indice.py @@ -6,7 +6,7 @@ from .utils import ( ) -class TraceIndex(object): +class TraceIndice(object): def __init__(self, node_list) -> None: self.node_list = node_list self.idx_trace_list = self._init_idx_trace_list()