mirror of https://github.com/hpcaitech/ColossalAI
rename
parent
c3d72f7db9
commit
da4076846d
|
@ -17,7 +17,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
|
|||
|
||||
import colossalai
|
||||
|
||||
from .chunk_region_search import ChunkRegionSearch
|
||||
from .search_chunk import SearchChunk
|
||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
|
||||
|
||||
CODEGEN_AVAILABLE = True
|
||||
|
@ -103,7 +103,7 @@ def emit_code_with_chunk(
|
|||
nodes,
|
||||
emit_node_func,
|
||||
delete_unused_value_func,
|
||||
chunk_region_search: ChunkRegionSearch,
|
||||
chunk_region_search: SearchChunk,
|
||||
chunk_infos,
|
||||
):
|
||||
"""Emit code with nested activation checkpoint
|
||||
|
@ -220,7 +220,7 @@ if CODEGEN_AVAILABLE:
|
|||
self.max_memory = max_memory
|
||||
self.meta_node = list(meta_graph.graph.nodes)
|
||||
# find the chunk regions
|
||||
self.chunk_region_search = ChunkRegionSearch(
|
||||
self.chunk_region_search = SearchChunk(
|
||||
meta_graph, max_memory, print_mem
|
||||
)
|
||||
self.chunk_infos = self.chunk_region_search.search_region()
|
||||
|
|
|
@ -6,7 +6,6 @@ from torch.fx.node import Node, map_arg
|
|||
|
||||
from colossalai.fx.profiler import activation_size, parameter_size
|
||||
|
||||
from .index_tracer import IndexTracer
|
||||
from .utils import (
|
||||
delete_free_var_from_last_use,
|
||||
find_idx_by_name,
|
||||
|
@ -15,7 +14,7 @@ from .utils import (
|
|||
)
|
||||
|
||||
|
||||
class MemoryEstimator(object):
|
||||
class EstimateMemory(object):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import copy
|
||||
|
||||
from .chunk_selector import ChunkSelector
|
||||
from .index_tracer import IndexTracer, ReorderGraph
|
||||
from .memory_estiamtor import MemoryEstimator
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_index import TraceIndex, ReorderGraph
|
||||
from .estiamte_memory import EstimateMemory
|
||||
from .utils import (
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
|
@ -10,15 +10,15 @@ from .utils import (
|
|||
)
|
||||
|
||||
|
||||
class ChunkRegionSearch(object):
|
||||
class SearchChunk(object):
|
||||
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
||||
self.gm = gm
|
||||
self.print_mem = print_mem
|
||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||
self.index_tracer = TraceIndex(list(gm.graph.nodes))
|
||||
self.index_tracer.trace_index()
|
||||
self.reorder_graph = ReorderGraph(self.index_tracer)
|
||||
self.memory_estimator = MemoryEstimator()
|
||||
self.chunk_selector = ChunkSelector(
|
||||
self.memory_estimator = EstimateMemory()
|
||||
self.chunk_selector = SelectChunk(
|
||||
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
|
||||
)
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
from .index_tracer import IndexTracer, ReorderGraph
|
||||
from .memory_estiamtor import MemoryEstimator
|
||||
from .trace_index import TraceIndex, ReorderGraph
|
||||
from .estiamte_memory import EstimateMemory
|
||||
from .utils import is_non_compute_node
|
||||
|
||||
|
||||
class ChunkSelector(object):
|
||||
class SelectChunk(object):
|
||||
def __init__(
|
||||
self,
|
||||
index_tracer: IndexTracer,
|
||||
memory_estimator: MemoryEstimator,
|
||||
index_tracer: TraceIndex,
|
||||
memory_estimator: EstimateMemory,
|
||||
reorder_graph: ReorderGraph,
|
||||
max_memory=None,
|
||||
):
|
|
@ -10,7 +10,7 @@ from .utils import (
|
|||
)
|
||||
|
||||
|
||||
class IndexTracer(object):
|
||||
class TraceIndex(object):
|
||||
def __init__(self, node_list) -> None:
|
||||
self.node_list = node_list
|
||||
self.idx_trace_list = self._init_idx_trace_list()
|
||||
|
@ -982,7 +982,7 @@ class IndexTracer(object):
|
|||
|
||||
|
||||
class ReorderGraph(object):
|
||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
||||
def __init__(self, index_tracer: TraceIndex) -> None:
|
||||
self.index_tracer = index_tracer
|
||||
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}
|
||||
|
|
@ -104,7 +104,7 @@ def benchmark_evoformer():
|
|||
model = evoformer_base().cuda()
|
||||
|
||||
# build autochunk model
|
||||
# max_memory = 10000 # MB fit memory mode
|
||||
# max_memory = 1000 # MB fit memory mode
|
||||
max_memory = None # min memory mode
|
||||
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
||||
|
||||
|
|
Loading…
Reference in New Issue