From 8a634af2f5510954e7a992c0ee894d22cf9e26d2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:19:45 +0800 Subject: [PATCH] close mem and code print --- colossalai/autochunk/autochunk_codegen.py | 4 ++-- colossalai/autochunk/chunk_region_search.py | 11 +++++++---- tests/test_autochunk/benchmark_autochunk.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 58a8c3751..dcc6bba9e 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -214,13 +214,13 @@ def emit_code_with_chunk( if CODEGEN_AVAILABLE: class AutoChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None): + def __init__(self, meta_graph, max_memory=None, print_mem=False): super().__init__() self.meta_graph = meta_graph self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) + self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem) self.chunk_infos = self.chunk_region_search.search_region() def _gen_python_code( diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py index 0d0825f25..76b02cade 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/chunk_region_search.py @@ -6,8 +6,9 @@ from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, class ChunkRegionSearch(object): - def __init__(self, gm, max_memory=None) -> None: + def __init__(self, gm, max_memory=None, print_mem=False) -> None: self.gm = gm + self.print_mem = print_mem self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) @@ -204,8 +205,10 @@ class ChunkRegionSearch(object): ) if self._stop_search(init_mem_peak, mem_peak): break - self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True - ) + if self.print_mem: + self.print_mem = False + self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos, print_mem=True + ) return chunk_infos diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 702eb7026..9daaa364a 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -64,7 +64,7 @@ def _build_autochunk(model, max_memory, node, pair): ) # set code_gen - codegen = AutoChunkCodeGen(gm_prop, max_memory) + codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile()