mirror of https://github.com/hpcaitech/ColossalAI
code format
parent
8a634af2f5
commit
2bde9d2b7f
|
@ -220,7 +220,9 @@ if CODEGEN_AVAILABLE:
|
||||||
self.max_memory = max_memory
|
self.max_memory = max_memory
|
||||||
self.meta_node = list(meta_graph.graph.nodes)
|
self.meta_node = list(meta_graph.graph.nodes)
|
||||||
# find the chunk regions
|
# find the chunk regions
|
||||||
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
|
self.chunk_region_search = ChunkRegionSearch(
|
||||||
|
meta_graph, max_memory, print_mem
|
||||||
|
)
|
||||||
self.chunk_infos = self.chunk_region_search.search_region()
|
self.chunk_infos = self.chunk_region_search.search_region()
|
||||||
|
|
||||||
def _gen_python_code(
|
def _gen_python_code(
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from .chunk_selector import ChunkSelector
|
||||||
from .index_tracer import IndexTracer
|
from .index_tracer import IndexTracer
|
||||||
from .memory_estiamtor import MemoryEstimator
|
from .memory_estiamtor import MemoryEstimator
|
||||||
from .chunk_selector import ChunkSelector
|
from .utils import (
|
||||||
import copy
|
get_node_shape,
|
||||||
from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape
|
is_non_compute_node,
|
||||||
|
is_non_compute_node_except_placeholder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChunkRegionSearch(object):
|
class ChunkRegionSearch(object):
|
||||||
|
@ -11,7 +16,7 @@ class ChunkRegionSearch(object):
|
||||||
self.print_mem = print_mem
|
self.print_mem = print_mem
|
||||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
self.memory_estimator = MemoryEstimator()
|
||||||
self.chunk_selector = ChunkSelector(
|
self.chunk_selector = ChunkSelector(
|
||||||
self.index_tracer, self.memory_estimator, max_memory=max_memory
|
self.index_tracer, self.memory_estimator, max_memory=max_memory
|
||||||
)
|
)
|
||||||
|
@ -211,4 +216,3 @@ class ChunkRegionSearch(object):
|
||||||
self.index_tracer.node_list, chunk_infos, print_mem=True
|
self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||||
)
|
)
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from .utils import (
|
||||||
|
|
||||||
|
|
||||||
class MemoryEstimator(object):
|
class MemoryEstimator(object):
|
||||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_meta_node_size(self, x):
|
def _get_meta_node_size(self, x):
|
||||||
|
|
Loading…
Reference in New Issue