From c3d72f7db9e2fc28e9a3aa92749f08c7a7d51e42 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 16:53:01 +0800 Subject: [PATCH] separate reorder --- colossalai/autochunk/autochunk_codegen.py | 4 +-- colossalai/autochunk/chunk_region_search.py | 7 +++-- colossalai/autochunk/chunk_selector.py | 8 ++++-- colossalai/autochunk/index_tracer.py | 31 ++++++++++++--------- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index fbd5d5e36..b4144196a 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -103,7 +103,7 @@ def emit_code_with_chunk( nodes, emit_node_func, delete_unused_value_func, - chunk_region_search, + chunk_region_search: ChunkRegionSearch, chunk_infos, ): """Emit code with nested activation checkpoint @@ -133,7 +133,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] - node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) + node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list) node_idx = 0 region_idx = 0 within_chunk_region = False diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py index 7a0e8a36c..47e2fe13c 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/chunk_region_search.py @@ -1,7 +1,7 @@ import copy from .chunk_selector import ChunkSelector -from .index_tracer import IndexTracer +from .index_tracer import IndexTracer, ReorderGraph from .memory_estiamtor import MemoryEstimator from .utils import ( get_node_shape, @@ -16,9 +16,10 @@ class ChunkRegionSearch(object): self.print_mem = print_mem self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() + self.reorder_graph = ReorderGraph(self.index_tracer) self.memory_estimator = MemoryEstimator() 
self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, max_memory=max_memory + self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -175,7 +176,7 @@ class ChunkRegionSearch(object): best_chunk_region = self.chunk_selector._select_best_chunk_region( possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak ) - best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) + best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): diff --git a/colossalai/autochunk/chunk_selector.py b/colossalai/autochunk/chunk_selector.py index aeab66572..119ff8aaf 100644 --- a/colossalai/autochunk/chunk_selector.py +++ b/colossalai/autochunk/chunk_selector.py @@ -1,4 +1,4 @@ -from .index_tracer import IndexTracer +from .index_tracer import IndexTracer, ReorderGraph from .memory_estiamtor import MemoryEstimator from .utils import is_non_compute_node @@ -8,10 +8,12 @@ class ChunkSelector(object): self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, + reorder_graph: ReorderGraph, max_memory=None, ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator + self.reorder_graph = reorder_graph if max_memory is not None: self.stratge = "fit_memory" self.max_memory = max_memory # MB @@ -64,7 +66,7 @@ class ChunkSelector(object): regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( + cur_node_list, cur_region = self.reorder_graph.tmp_reorder( self.index_tracer.node_list, cur_region ) cur_chunk_infos = chunk_infos + [cur_region] @@ -174,7 +176,7 @@ class ChunkSelector(object): regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( + cur_node_list, cur_region = 
self.reorder_graph.tmp_reorder( self.index_tracer.node_list, cur_region ) cur_chunk_infos = chunk_infos + [cur_region] diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 202044763..8b4d3aabd 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -17,7 +17,6 @@ class IndexTracer(object): self.idx_trace_equal = [] self.idx_view_list = {} self.idx_count = -1 - self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} def _init_idx_trace_list(self): idx_trace_list = [] @@ -981,24 +980,30 @@ class IndexTracer(object): chunk_info["reshape_size"] = reshape_size return chunk_info + +class ReorderGraph(object): + def __init__(self, index_tracer: IndexTracer) -> None: + self.index_tracer = index_tracer + self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} + def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.node_list))} + reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes + find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes ] # put prepose nodes ahead for idx, n in enumerate(chunk_prepose_nodes): n_idx = chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes - for n in self.node_list[chunk_region_start : chunk_region_end + 1]: + for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: if n in chunk_prepose_nodes: continue - n_idx = find_idx_by_name(n.name, self.node_list) + n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos @@ -1024,25 
+1029,25 @@ class IndexTracer(object): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.node_list))] + new_node_list = [None for _ in range(len(self.index_tracer.node_list))] for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.node_list[old_idx] - self.node_list = new_node_list + new_node_list[new_idx] = self.index_tracer.node_list[old_idx] + self.index_tracer.node_list = new_node_list def _reorder_idx_trace(self, reorder_map): # reorder list - new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] + new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] - self.idx_trace_list = new_idx_trace_list + new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] + self.index_tracer.idx_trace_list = new_idx_trace_list # update compute - for idx_trace in self.idx_trace_list: + for idx_trace in self.index_tracer.idx_trace_list: compute = idx_trace["compute"] for dim_compute in compute: for idx, i in enumerate(dim_compute): dim_compute[idx] = reorder_map[i] # update source - for idx_trace in self.idx_trace_list: + for idx_trace in self.index_tracer.idx_trace_list: source = idx_trace["source"] for dim_idx, dim_source in enumerate(source): new_dim_source = {}