diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 39728cb79..891753faa 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -103,7 +103,7 @@ def emit_code_with_chunk( nodes, emit_node_func, delete_unused_value_func, - chunk_region_search: SearchChunk, + search_chunk: SearchChunk, chunk_infos, ): """Emit code with nested activation checkpoint @@ -133,7 +133,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] - node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list) + node_list = search_chunk.reorder_graph.reorder_node_list(node_list) node_idx = 0 region_idx = 0 within_chunk_region = False @@ -167,7 +167,7 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: - meta_node = chunk_region_search.trace_index.node_list[node_idx] + meta_node = search_chunk.trace_index.node_list[node_idx] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ "chunk_dim" ] @@ -220,10 +220,8 @@ if CODEGEN_AVAILABLE: self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = SearchChunk( - meta_graph, max_memory, print_mem - ) - self.chunk_infos = self.chunk_region_search.search_region() + self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) + self.chunk_infos = self.search_chunk.search_region() def _gen_python_code( self, nodes, root_module: str, namespace: _Namespace @@ -458,7 +456,7 @@ if CODEGEN_AVAILABLE: nodes, emit_node, delete_unused_values, - self.chunk_region_search, + self.search_chunk, self.chunk_infos, ) diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py index 7b9f4a20d..bf4420eac 100644 --- a/colossalai/autochunk/reorder_graph.py +++ b/colossalai/autochunk/reorder_graph.py @@ -3,28 +3,31 @@ from .utils import find_idx_by_name class ReorderGraph(object): - def __init__(self, index_tracer: TraceIndex) -> None: - self.index_tracer = index_tracer - self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} + def __init__(self, trace_index: TraceIndex) -> None: + self.trace_index = trace_index + self.all_reorder_map = { + i: i for i in range(len(self.trace_index.idx_trace_list)) + } def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} + reorder_map = {i: i for i in range(len(self.trace_index.node_list))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes + find_idx_by_name(i.name, self.trace_index.node_list) + for i in chunk_prepose_nodes ] # put prepose nodes ahead for idx, n in enumerate(chunk_prepose_nodes): n_idx = chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes - for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: + for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]: if n in chunk_prepose_nodes: continue - n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) + n_idx = find_idx_by_name(n.name, self.trace_index.node_list) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos @@ -50,25 +53,25 @@ class ReorderGraph(object): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.index_tracer.node_list))] + new_node_list = [None for _ in range(len(self.trace_index.node_list))] for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.index_tracer.node_list[old_idx] - self.index_tracer.node_list = new_node_list + new_node_list[new_idx] = self.trace_index.node_list[old_idx] + self.trace_index.node_list = new_node_list def _reorder_idx_trace(self, reorder_map): # reorder list - new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] + new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))] for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] - self.index_tracer.idx_trace_list = new_idx_trace_list + new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx] + self.trace_index.idx_trace_list = new_idx_trace_list # update compute - for idx_trace in self.index_tracer.idx_trace_list: + for idx_trace in self.trace_index.idx_trace_list: compute = idx_trace["compute"] for dim_compute in compute: for idx, i in enumerate(dim_compute): dim_compute[idx] = reorder_map[i] # update source - for idx_trace in self.index_tracer.idx_trace_list: + for idx_trace in self.trace_index.idx_trace_list: source = idx_trace["source"] for dim_idx, dim_source in enumerate(source): new_dim_source = {} diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 030b13bdb..e2c8de74e 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -1,10 +1,10 @@ import copy -from .select_chunk import SelectChunk -from .trace_index import TraceIndex -from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory +from .reorder_graph import ReorderGraph +from .select_chunk import SelectChunk from .trace_flow import TraceFlow +from .trace_index import TraceIndex from .utils import ( get_node_shape, is_non_compute_node, @@ -22,7 +22,10 @@ class SearchChunk(object): self.reorder_graph = ReorderGraph(self.trace_index) self.estimate_memory = EstimateMemory() self.select_chunk = SelectChunk( - self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory + self.trace_index, + self.estimate_memory, + self.reorder_graph, + max_memory=max_memory, ) def _find_peak_node(self, mem_peak): diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 30f4226f5..bdc64528e 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,19 +1,19 @@ -from .trace_index import TraceIndex -from .reorder_graph import ReorderGraph from .estiamte_memory import EstimateMemory +from .reorder_graph import ReorderGraph +from .trace_index import TraceIndex from .utils import is_non_compute_node class SelectChunk(object): def __init__( self, - index_tracer: TraceIndex, - memory_estimator: EstimateMemory, + trace_index: TraceIndex, + estimate_memory: EstimateMemory, reorder_graph: ReorderGraph, max_memory=None, ): - self.index_tracer = index_tracer - self.memory_estimator = memory_estimator + self.index_tracer = trace_index + self.memory_estimator = estimate_memory self.reorder_graph = reorder_graph if max_memory is not None: self.stratge = "fit_memory" diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index f372fa913..7139e7e04 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -81,7 +81,9 @@ class TraceFlow(object): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k]) + inherit_dim = self._find_inherit_dim( + input_node, v, self.trace_index.node_list[k] + ) if inherit_dim: input_dim_after_node[k] = inherit_dim @@ -217,7 +219,9 @@ class TraceFlow(object): for arg in arg_list: if not ( start_idx - <= find_idx_by_name(arg.name, self.trace_index.node_list) + <= find_idx_by_name( + arg.name, self.trace_index.node_list + ) < end_idx ): continue @@ -255,7 +259,9 @@ class TraceFlow(object): if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: - user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim] + user_source = self.trace_index._find_source_trace_from_node( + user + )[chunk_dim] if input_node_idx in user_source: input_dict[user_idx] = user_source[input_node_idx] else: