pull/2364/head
oahzxl 2023-01-06 17:09:37 +08:00
parent c3d72f7db9
commit da4076846d
6 changed files with 19 additions and 20 deletions

View File

@ -17,7 +17,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
import colossalai
from .chunk_region_search import ChunkRegionSearch
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
CODEGEN_AVAILABLE = True
@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes,
emit_node_func,
delete_unused_value_func,
chunk_region_search: ChunkRegionSearch,
chunk_region_search: SearchChunk,
chunk_infos,
):
"""Emit code with nested activation checkpoint
@ -220,7 +220,7 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = ChunkRegionSearch(
self.chunk_region_search = SearchChunk(
meta_graph, max_memory, print_mem
)
self.chunk_infos = self.chunk_region_search.search_region()

View File

@ -6,7 +6,6 @@ from torch.fx.node import Node, map_arg
from colossalai.fx.profiler import activation_size, parameter_size
from .index_tracer import IndexTracer
from .utils import (
delete_free_var_from_last_use,
find_idx_by_name,
@ -15,7 +14,7 @@ from .utils import (
)
class MemoryEstimator(object):
class EstimateMemory(object):
def __init__(self) -> None:
pass

View File

@ -1,8 +1,8 @@
import copy
from .chunk_selector import ChunkSelector
from .index_tracer import IndexTracer, ReorderGraph
from .memory_estiamtor import MemoryEstimator
from .select_chunk import SelectChunk
from .trace_index import TraceIndex, ReorderGraph
from .estiamte_memory import EstimateMemory
from .utils import (
get_node_shape,
is_non_compute_node,
@ -10,15 +10,15 @@ from .utils import (
)
class ChunkRegionSearch(object):
class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
self.print_mem = print_mem
self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer = TraceIndex(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.reorder_graph = ReorderGraph(self.index_tracer)
self.memory_estimator = MemoryEstimator()
self.chunk_selector = ChunkSelector(
self.memory_estimator = EstimateMemory()
self.chunk_selector = SelectChunk(
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
)

View File

@ -1,13 +1,13 @@
from .index_tracer import IndexTracer, ReorderGraph
from .memory_estiamtor import MemoryEstimator
from .trace_index import TraceIndex, ReorderGraph
from .estiamte_memory import EstimateMemory
from .utils import is_non_compute_node
class ChunkSelector(object):
class SelectChunk(object):
def __init__(
self,
index_tracer: IndexTracer,
memory_estimator: MemoryEstimator,
index_tracer: TraceIndex,
memory_estimator: EstimateMemory,
reorder_graph: ReorderGraph,
max_memory=None,
):

View File

@ -10,7 +10,7 @@ from .utils import (
)
class IndexTracer(object):
class TraceIndex(object):
def __init__(self, node_list) -> None:
self.node_list = node_list
self.idx_trace_list = self._init_idx_trace_list()
@ -982,7 +982,7 @@ class IndexTracer(object):
class ReorderGraph(object):
def __init__(self, index_tracer: IndexTracer) -> None:
def __init__(self, index_tracer: TraceIndex) -> None:
self.index_tracer = index_tracer
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}

View File

@ -104,7 +104,7 @@ def benchmark_evoformer():
model = evoformer_base().cuda()
# build autochunk model
# max_memory = 10000 # MB fit memory mode
# max_memory = 1000 # MB fit memory mode
max_memory = None # min memory mode
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)