pull/2364/head
oahzxl 2023-01-06 17:09:37 +08:00
parent c3d72f7db9
commit da4076846d
6 changed files with 19 additions and 20 deletions

View File

@ -17,7 +17,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
import colossalai import colossalai
from .chunk_region_search import ChunkRegionSearch from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
CODEGEN_AVAILABLE = True CODEGEN_AVAILABLE = True
@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes, nodes,
emit_node_func, emit_node_func,
delete_unused_value_func, delete_unused_value_func,
chunk_region_search: ChunkRegionSearch, chunk_region_search: SearchChunk,
chunk_infos, chunk_infos,
): ):
"""Emit code with nested activation checkpoint """Emit code with nested activation checkpoint
@ -220,7 +220,7 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes) self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions # find the chunk regions
self.chunk_region_search = ChunkRegionSearch( self.chunk_region_search = SearchChunk(
meta_graph, max_memory, print_mem meta_graph, max_memory, print_mem
) )
self.chunk_infos = self.chunk_region_search.search_region() self.chunk_infos = self.chunk_region_search.search_region()

View File

@ -6,7 +6,6 @@ from torch.fx.node import Node, map_arg
from colossalai.fx.profiler import activation_size, parameter_size from colossalai.fx.profiler import activation_size, parameter_size
from .index_tracer import IndexTracer
from .utils import ( from .utils import (
delete_free_var_from_last_use, delete_free_var_from_last_use,
find_idx_by_name, find_idx_by_name,
@ -15,7 +14,7 @@ from .utils import (
) )
class MemoryEstimator(object): class EstimateMemory(object):
def __init__(self) -> None: def __init__(self) -> None:
pass pass

View File

@ -1,8 +1,8 @@
import copy import copy
from .chunk_selector import ChunkSelector from .select_chunk import SelectChunk
from .index_tracer import IndexTracer, ReorderGraph from .trace_index import TraceIndex, ReorderGraph
from .memory_estiamtor import MemoryEstimator from .estiamte_memory import EstimateMemory
from .utils import ( from .utils import (
get_node_shape, get_node_shape,
is_non_compute_node, is_non_compute_node,
@ -10,15 +10,15 @@ from .utils import (
) )
class ChunkRegionSearch(object): class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False) -> None: def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm self.gm = gm
self.print_mem = print_mem self.print_mem = print_mem
self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer = TraceIndex(list(gm.graph.nodes))
self.index_tracer.trace_index() self.index_tracer.trace_index()
self.reorder_graph = ReorderGraph(self.index_tracer) self.reorder_graph = ReorderGraph(self.index_tracer)
self.memory_estimator = MemoryEstimator() self.memory_estimator = EstimateMemory()
self.chunk_selector = ChunkSelector( self.chunk_selector = SelectChunk(
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
) )

View File

@ -1,13 +1,13 @@
from .index_tracer import IndexTracer, ReorderGraph from .trace_index import TraceIndex, ReorderGraph
from .memory_estiamtor import MemoryEstimator from .estiamte_memory import EstimateMemory
from .utils import is_non_compute_node from .utils import is_non_compute_node
class ChunkSelector(object): class SelectChunk(object):
def __init__( def __init__(
self, self,
index_tracer: IndexTracer, index_tracer: TraceIndex,
memory_estimator: MemoryEstimator, memory_estimator: EstimateMemory,
reorder_graph: ReorderGraph, reorder_graph: ReorderGraph,
max_memory=None, max_memory=None,
): ):

View File

@ -10,7 +10,7 @@ from .utils import (
) )
class IndexTracer(object): class TraceIndex(object):
def __init__(self, node_list) -> None: def __init__(self, node_list) -> None:
self.node_list = node_list self.node_list = node_list
self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_list = self._init_idx_trace_list()
@ -982,7 +982,7 @@ class IndexTracer(object):
class ReorderGraph(object): class ReorderGraph(object):
def __init__(self, index_tracer: IndexTracer) -> None: def __init__(self, index_tracer: TraceIndex) -> None:
self.index_tracer = index_tracer self.index_tracer = index_tracer
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}

View File

@ -104,7 +104,7 @@ def benchmark_evoformer():
model = evoformer_base().cuda() model = evoformer_base().cuda()
# build autochunk model # build autochunk model
# max_memory = 10000 # MB fit memory mode # max_memory = 1000 # MB fit memory mode
max_memory = None # min memory mode max_memory = None # min memory mode
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)