mirror of https://github.com/hpcaitech/ColossalAI
close mem and code print
parent
1a6d2a740b
commit
8a634af2f5
|
@ -214,13 +214,13 @@ def emit_code_with_chunk(
|
|||
if CODEGEN_AVAILABLE:
|
||||
|
||||
class AutoChunkCodeGen(CodeGen):
|
||||
def __init__(self, meta_graph, max_memory=None):
|
||||
def __init__(self, meta_graph, max_memory=None, print_mem=False):
|
||||
super().__init__()
|
||||
self.meta_graph = meta_graph
|
||||
self.max_memory = max_memory
|
||||
self.meta_node = list(meta_graph.graph.nodes)
|
||||
# find the chunk regions
|
||||
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
|
||||
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
|
||||
self.chunk_infos = self.chunk_region_search.search_region()
|
||||
|
||||
def _gen_python_code(
|
||||
|
|
|
@ -6,8 +6,9 @@ from .utils import is_non_compute_node, is_non_compute_node_except_placeholder,
|
|||
|
||||
|
||||
class ChunkRegionSearch(object):
|
||||
def __init__(self, gm, max_memory=None) -> None:
|
||||
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
||||
self.gm = gm
|
||||
self.print_mem = print_mem
|
||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||
self.index_tracer.trace_index()
|
||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||
|
@ -204,8 +205,10 @@ class ChunkRegionSearch(object):
|
|||
)
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
self.memory_estimator.estimate_chunk_inference_mem(
|
||||
self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.memory_estimator.estimate_chunk_inference_mem(
|
||||
self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
return chunk_infos
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ def _build_autochunk(model, max_memory, node, pair):
|
|||
)
|
||||
|
||||
# set code_gen
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory)
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
|
Loading…
Reference in New Issue