close mem and code print

pull/2364/head
oahzxl 2023-01-06 14:19:45 +08:00
parent 1a6d2a740b
commit 8a634af2f5
3 changed files with 10 additions and 7 deletions

View File

@ -214,13 +214,13 @@ def emit_code_with_chunk(
if CODEGEN_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self, meta_graph, max_memory=None):
def __init__(self, meta_graph, max_memory=None, print_mem=False):
super().__init__()
self.meta_graph = meta_graph
self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
self.chunk_infos = self.chunk_region_search.search_region()
def _gen_python_code(

View File

@ -6,8 +6,9 @@ from .utils import is_non_compute_node, is_non_compute_node_except_placeholder,
class ChunkRegionSearch(object):
def __init__(self, gm, max_memory=None) -> None:
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
self.print_mem = print_mem
self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer)
@ -204,8 +205,10 @@ class ChunkRegionSearch(object):
)
if self._stop_search(init_mem_peak, mem_peak):
break
self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos, print_mem=True
)
if self.print_mem:
self.print_mem = False
self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos, print_mem=True
)
return chunk_infos

View File

@ -64,7 +64,7 @@ def _build_autochunk(model, max_memory, node, pair):
)
# set code_gen
codegen = AutoChunkCodeGen(gm_prop, max_memory)
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()