mirror of https://github.com/hpcaitech/ColossalAI
code format
parent
8a634af2f5
commit
2bde9d2b7f
|
@ -220,7 +220,9 @@ if CODEGEN_AVAILABLE:
|
|||
self.max_memory = max_memory
|
||||
self.meta_node = list(meta_graph.graph.nodes)
|
||||
# find the chunk regions
|
||||
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
|
||||
self.chunk_region_search = ChunkRegionSearch(
|
||||
meta_graph, max_memory, print_mem
|
||||
)
|
||||
self.chunk_infos = self.chunk_region_search.search_region()
|
||||
|
||||
def _gen_python_code(
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
import copy
|
||||
|
||||
from .chunk_selector import ChunkSelector
|
||||
from .index_tracer import IndexTracer
|
||||
from .memory_estiamtor import MemoryEstimator
|
||||
from .chunk_selector import ChunkSelector
|
||||
import copy
|
||||
from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape
|
||||
from .utils import (
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
is_non_compute_node_except_placeholder,
|
||||
)
|
||||
|
||||
|
||||
class ChunkRegionSearch(object):
|
||||
|
@ -11,7 +16,7 @@ class ChunkRegionSearch(object):
|
|||
self.print_mem = print_mem
|
||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||
self.index_tracer.trace_index()
|
||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||
self.memory_estimator = MemoryEstimator()
|
||||
self.chunk_selector = ChunkSelector(
|
||||
self.index_tracer, self.memory_estimator, max_memory=max_memory
|
||||
)
|
||||
|
@ -211,4 +216,3 @@ class ChunkRegionSearch(object):
|
|||
self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
return chunk_infos
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from .utils import (
|
|||
|
||||
|
||||
class MemoryEstimator(object):
|
||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_meta_node_size(self, x):
|
||||
|
|
Loading…
Reference in New Issue