code format

pull/2364/head
oahzxl 2023-01-06 14:21:49 +08:00
parent 8a634af2f5
commit 2bde9d2b7f
3 changed files with 13 additions and 7 deletions

View File

@ -220,7 +220,9 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
self.chunk_region_search = ChunkRegionSearch(
meta_graph, max_memory, print_mem
)
self.chunk_infos = self.chunk_region_search.search_region()
def _gen_python_code(

View File

@ -1,8 +1,13 @@
import copy
from .chunk_selector import ChunkSelector
from .index_tracer import IndexTracer
from .memory_estiamtor import MemoryEstimator
from .chunk_selector import ChunkSelector
import copy
from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape
from .utils import (
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
class ChunkRegionSearch(object):
@ -11,7 +16,7 @@ class ChunkRegionSearch(object):
self.print_mem = print_mem
self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer)
self.memory_estimator = MemoryEstimator()
self.chunk_selector = ChunkSelector(
self.index_tracer, self.memory_estimator, max_memory=max_memory
)
@ -211,4 +216,3 @@ class ChunkRegionSearch(object):
self.index_tracer.node_list, chunk_infos, print_mem=True
)
return chunk_infos

View File

@ -16,7 +16,7 @@ from .utils import (
class MemoryEstimator(object):
def __init__(self, index_tracer: IndexTracer) -> None:
def __init__(self) -> None:
pass
def _get_meta_node_size(self, x):