mirror of https://github.com/hpcaitech/ColossalAI
rename
parent
c3d72f7db9
commit
da4076846d
|
@ -17,7 +17,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
|
||||||
from .chunk_region_search import ChunkRegionSearch
|
from .search_chunk import SearchChunk
|
||||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
|
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
|
||||||
|
|
||||||
CODEGEN_AVAILABLE = True
|
CODEGEN_AVAILABLE = True
|
||||||
|
@ -103,7 +103,7 @@ def emit_code_with_chunk(
|
||||||
nodes,
|
nodes,
|
||||||
emit_node_func,
|
emit_node_func,
|
||||||
delete_unused_value_func,
|
delete_unused_value_func,
|
||||||
chunk_region_search: ChunkRegionSearch,
|
chunk_region_search: SearchChunk,
|
||||||
chunk_infos,
|
chunk_infos,
|
||||||
):
|
):
|
||||||
"""Emit code with nested activation checkpoint
|
"""Emit code with nested activation checkpoint
|
||||||
|
@ -220,7 +220,7 @@ if CODEGEN_AVAILABLE:
|
||||||
self.max_memory = max_memory
|
self.max_memory = max_memory
|
||||||
self.meta_node = list(meta_graph.graph.nodes)
|
self.meta_node = list(meta_graph.graph.nodes)
|
||||||
# find the chunk regions
|
# find the chunk regions
|
||||||
self.chunk_region_search = ChunkRegionSearch(
|
self.chunk_region_search = SearchChunk(
|
||||||
meta_graph, max_memory, print_mem
|
meta_graph, max_memory, print_mem
|
||||||
)
|
)
|
||||||
self.chunk_infos = self.chunk_region_search.search_region()
|
self.chunk_infos = self.chunk_region_search.search_region()
|
||||||
|
|
|
@ -6,7 +6,6 @@ from torch.fx.node import Node, map_arg
|
||||||
|
|
||||||
from colossalai.fx.profiler import activation_size, parameter_size
|
from colossalai.fx.profiler import activation_size, parameter_size
|
||||||
|
|
||||||
from .index_tracer import IndexTracer
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
delete_free_var_from_last_use,
|
delete_free_var_from_last_use,
|
||||||
find_idx_by_name,
|
find_idx_by_name,
|
||||||
|
@ -15,7 +14,7 @@ from .utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryEstimator(object):
|
class EstimateMemory(object):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from .chunk_selector import ChunkSelector
|
from .select_chunk import SelectChunk
|
||||||
from .index_tracer import IndexTracer, ReorderGraph
|
from .trace_index import TraceIndex, ReorderGraph
|
||||||
from .memory_estiamtor import MemoryEstimator
|
from .estiamte_memory import EstimateMemory
|
||||||
from .utils import (
|
from .utils import (
|
||||||
get_node_shape,
|
get_node_shape,
|
||||||
is_non_compute_node,
|
is_non_compute_node,
|
||||||
|
@ -10,15 +10,15 @@ from .utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChunkRegionSearch(object):
|
class SearchChunk(object):
|
||||||
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
self.print_mem = print_mem
|
self.print_mem = print_mem
|
||||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
self.index_tracer = TraceIndex(list(gm.graph.nodes))
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
self.reorder_graph = ReorderGraph(self.index_tracer)
|
self.reorder_graph = ReorderGraph(self.index_tracer)
|
||||||
self.memory_estimator = MemoryEstimator()
|
self.memory_estimator = EstimateMemory()
|
||||||
self.chunk_selector = ChunkSelector(
|
self.chunk_selector = SelectChunk(
|
||||||
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
|
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
from .index_tracer import IndexTracer, ReorderGraph
|
from .trace_index import TraceIndex, ReorderGraph
|
||||||
from .memory_estiamtor import MemoryEstimator
|
from .estiamte_memory import EstimateMemory
|
||||||
from .utils import is_non_compute_node
|
from .utils import is_non_compute_node
|
||||||
|
|
||||||
|
|
||||||
class ChunkSelector(object):
|
class SelectChunk(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_tracer: IndexTracer,
|
index_tracer: TraceIndex,
|
||||||
memory_estimator: MemoryEstimator,
|
memory_estimator: EstimateMemory,
|
||||||
reorder_graph: ReorderGraph,
|
reorder_graph: ReorderGraph,
|
||||||
max_memory=None,
|
max_memory=None,
|
||||||
):
|
):
|
|
@ -10,7 +10,7 @@ from .utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class IndexTracer(object):
|
class TraceIndex(object):
|
||||||
def __init__(self, node_list) -> None:
|
def __init__(self, node_list) -> None:
|
||||||
self.node_list = node_list
|
self.node_list = node_list
|
||||||
self.idx_trace_list = self._init_idx_trace_list()
|
self.idx_trace_list = self._init_idx_trace_list()
|
||||||
|
@ -982,7 +982,7 @@ class IndexTracer(object):
|
||||||
|
|
||||||
|
|
||||||
class ReorderGraph(object):
|
class ReorderGraph(object):
|
||||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
def __init__(self, index_tracer: TraceIndex) -> None:
|
||||||
self.index_tracer = index_tracer
|
self.index_tracer = index_tracer
|
||||||
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}
|
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}
|
||||||
|
|
|
@ -104,7 +104,7 @@ def benchmark_evoformer():
|
||||||
model = evoformer_base().cuda()
|
model = evoformer_base().cuda()
|
||||||
|
|
||||||
# build autochunk model
|
# build autochunk model
|
||||||
# max_memory = 10000 # MB fit memory mode
|
# max_memory = 1000 # MB fit memory mode
|
||||||
max_memory = None # min memory mode
|
max_memory = None # min memory mode
|
||||||
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue