mirror of https://github.com/hpcaitech/ColossalAI
separate reorder
parent
6685a9d022
commit
c3d72f7db9
|
@ -103,7 +103,7 @@ def emit_code_with_chunk(
|
||||||
nodes,
|
nodes,
|
||||||
emit_node_func,
|
emit_node_func,
|
||||||
delete_unused_value_func,
|
delete_unused_value_func,
|
||||||
chunk_region_search,
|
chunk_region_search: ChunkRegionSearch,
|
||||||
chunk_infos,
|
chunk_infos,
|
||||||
):
|
):
|
||||||
"""Emit code with nested activation checkpoint
|
"""Emit code with nested activation checkpoint
|
||||||
|
@ -133,7 +133,7 @@ def emit_code_with_chunk(
|
||||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
||||||
|
|
||||||
node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
|
node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list)
|
||||||
node_idx = 0
|
node_idx = 0
|
||||||
region_idx = 0
|
region_idx = 0
|
||||||
within_chunk_region = False
|
within_chunk_region = False
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from .chunk_selector import ChunkSelector
|
from .chunk_selector import ChunkSelector
|
||||||
from .index_tracer import IndexTracer
|
from .index_tracer import IndexTracer, ReorderGraph
|
||||||
from .memory_estiamtor import MemoryEstimator
|
from .memory_estiamtor import MemoryEstimator
|
||||||
from .utils import (
|
from .utils import (
|
||||||
get_node_shape,
|
get_node_shape,
|
||||||
|
@ -16,9 +16,10 @@ class ChunkRegionSearch(object):
|
||||||
self.print_mem = print_mem
|
self.print_mem = print_mem
|
||||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
|
self.reorder_graph = ReorderGraph(self.index_tracer)
|
||||||
self.memory_estimator = MemoryEstimator()
|
self.memory_estimator = MemoryEstimator()
|
||||||
self.chunk_selector = ChunkSelector(
|
self.chunk_selector = ChunkSelector(
|
||||||
self.index_tracer, self.memory_estimator, max_memory=max_memory
|
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
|
||||||
)
|
)
|
||||||
|
|
||||||
def _find_peak_node(self, mem_peak):
|
def _find_peak_node(self, mem_peak):
|
||||||
|
@ -175,7 +176,7 @@ class ChunkRegionSearch(object):
|
||||||
best_chunk_region = self.chunk_selector._select_best_chunk_region(
|
best_chunk_region = self.chunk_selector._select_best_chunk_region(
|
||||||
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
|
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
|
||||||
)
|
)
|
||||||
best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
|
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||||
return best_chunk_region
|
return best_chunk_region
|
||||||
|
|
||||||
def _stop_search(self, init_mem_peak, mem_peak):
|
def _stop_search(self, init_mem_peak, mem_peak):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .index_tracer import IndexTracer
|
from .index_tracer import IndexTracer, ReorderGraph
|
||||||
from .memory_estiamtor import MemoryEstimator
|
from .memory_estiamtor import MemoryEstimator
|
||||||
from .utils import is_non_compute_node
|
from .utils import is_non_compute_node
|
||||||
|
|
||||||
|
@ -8,10 +8,12 @@ class ChunkSelector(object):
|
||||||
self,
|
self,
|
||||||
index_tracer: IndexTracer,
|
index_tracer: IndexTracer,
|
||||||
memory_estimator: MemoryEstimator,
|
memory_estimator: MemoryEstimator,
|
||||||
|
reorder_graph: ReorderGraph,
|
||||||
max_memory=None,
|
max_memory=None,
|
||||||
):
|
):
|
||||||
self.index_tracer = index_tracer
|
self.index_tracer = index_tracer
|
||||||
self.memory_estimator = memory_estimator
|
self.memory_estimator = memory_estimator
|
||||||
|
self.reorder_graph = reorder_graph
|
||||||
if max_memory is not None:
|
if max_memory is not None:
|
||||||
self.stratge = "fit_memory"
|
self.stratge = "fit_memory"
|
||||||
self.max_memory = max_memory # MB
|
self.max_memory = max_memory # MB
|
||||||
|
@ -64,7 +66,7 @@ class ChunkSelector(object):
|
||||||
regions_dict = []
|
regions_dict = []
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
cur_region = region.copy()
|
cur_region = region.copy()
|
||||||
cur_node_list, cur_region = self.index_tracer.tmp_reorder(
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||||
self.index_tracer.node_list, cur_region
|
self.index_tracer.node_list, cur_region
|
||||||
)
|
)
|
||||||
cur_chunk_infos = chunk_infos + [cur_region]
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
|
@ -174,7 +176,7 @@ class ChunkSelector(object):
|
||||||
regions_dict = []
|
regions_dict = []
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
cur_region = region.copy()
|
cur_region = region.copy()
|
||||||
cur_node_list, cur_region = self.index_tracer.tmp_reorder(
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||||
self.index_tracer.node_list, cur_region
|
self.index_tracer.node_list, cur_region
|
||||||
)
|
)
|
||||||
cur_chunk_infos = chunk_infos + [cur_region]
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
|
|
|
@ -17,7 +17,6 @@ class IndexTracer(object):
|
||||||
self.idx_trace_equal = []
|
self.idx_trace_equal = []
|
||||||
self.idx_view_list = {}
|
self.idx_view_list = {}
|
||||||
self.idx_count = -1
|
self.idx_count = -1
|
||||||
self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))}
|
|
||||||
|
|
||||||
def _init_idx_trace_list(self):
|
def _init_idx_trace_list(self):
|
||||||
idx_trace_list = []
|
idx_trace_list = []
|
||||||
|
@ -981,24 +980,30 @@ class IndexTracer(object):
|
||||||
chunk_info["reshape_size"] = reshape_size
|
chunk_info["reshape_size"] = reshape_size
|
||||||
return chunk_info
|
return chunk_info
|
||||||
|
|
||||||
|
|
||||||
|
class ReorderGraph(object):
|
||||||
|
def __init__(self, index_tracer: IndexTracer) -> None:
|
||||||
|
self.index_tracer = index_tracer
|
||||||
|
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}
|
||||||
|
|
||||||
def _get_reorder_map(self, chunk_info):
|
def _get_reorder_map(self, chunk_info):
|
||||||
reorder_map = {i: i for i in range(len(self.node_list))}
|
reorder_map = {i: i for i in range(len(self.index_tracer.node_list))}
|
||||||
|
|
||||||
chunk_region_start = chunk_info["region"][0]
|
chunk_region_start = chunk_info["region"][0]
|
||||||
chunk_region_end = chunk_info["region"][1]
|
chunk_region_end = chunk_info["region"][1]
|
||||||
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
||||||
chunk_prepose_nodes_idx = [
|
chunk_prepose_nodes_idx = [
|
||||||
find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes
|
find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes
|
||||||
]
|
]
|
||||||
# put prepose nodes ahead
|
# put prepose nodes ahead
|
||||||
for idx, n in enumerate(chunk_prepose_nodes):
|
for idx, n in enumerate(chunk_prepose_nodes):
|
||||||
n_idx = chunk_prepose_nodes_idx[idx]
|
n_idx = chunk_prepose_nodes_idx[idx]
|
||||||
reorder_map[n_idx] = chunk_region_start + idx
|
reorder_map[n_idx] = chunk_region_start + idx
|
||||||
# put other nodes after prepose nodes
|
# put other nodes after prepose nodes
|
||||||
for n in self.node_list[chunk_region_start : chunk_region_end + 1]:
|
for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]:
|
||||||
if n in chunk_prepose_nodes:
|
if n in chunk_prepose_nodes:
|
||||||
continue
|
continue
|
||||||
n_idx = find_idx_by_name(n.name, self.node_list)
|
n_idx = find_idx_by_name(n.name, self.index_tracer.node_list)
|
||||||
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
||||||
reorder_map[n_idx] = n_idx + pos
|
reorder_map[n_idx] = n_idx + pos
|
||||||
|
|
||||||
|
@ -1024,25 +1029,25 @@ class IndexTracer(object):
|
||||||
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
||||||
|
|
||||||
def _reorder_self_node_list(self, reorder_map):
|
def _reorder_self_node_list(self, reorder_map):
|
||||||
new_node_list = [None for _ in range(len(self.node_list))]
|
new_node_list = [None for _ in range(len(self.index_tracer.node_list))]
|
||||||
for old_idx, new_idx in reorder_map.items():
|
for old_idx, new_idx in reorder_map.items():
|
||||||
new_node_list[new_idx] = self.node_list[old_idx]
|
new_node_list[new_idx] = self.index_tracer.node_list[old_idx]
|
||||||
self.node_list = new_node_list
|
self.index_tracer.node_list = new_node_list
|
||||||
|
|
||||||
def _reorder_idx_trace(self, reorder_map):
|
def _reorder_idx_trace(self, reorder_map):
|
||||||
# reorder list
|
# reorder list
|
||||||
new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))]
|
new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))]
|
||||||
for old_idx, new_idx in reorder_map.items():
|
for old_idx, new_idx in reorder_map.items():
|
||||||
new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx]
|
new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx]
|
||||||
self.idx_trace_list = new_idx_trace_list
|
self.index_tracer.idx_trace_list = new_idx_trace_list
|
||||||
# update compute
|
# update compute
|
||||||
for idx_trace in self.idx_trace_list:
|
for idx_trace in self.index_tracer.idx_trace_list:
|
||||||
compute = idx_trace["compute"]
|
compute = idx_trace["compute"]
|
||||||
for dim_compute in compute:
|
for dim_compute in compute:
|
||||||
for idx, i in enumerate(dim_compute):
|
for idx, i in enumerate(dim_compute):
|
||||||
dim_compute[idx] = reorder_map[i]
|
dim_compute[idx] = reorder_map[i]
|
||||||
# update source
|
# update source
|
||||||
for idx_trace in self.idx_trace_list:
|
for idx_trace in self.index_tracer.idx_trace_list:
|
||||||
source = idx_trace["source"]
|
source = idx_trace["source"]
|
||||||
for dim_idx, dim_source in enumerate(source):
|
for dim_idx, dim_source in enumerate(source):
|
||||||
new_dim_source = {}
|
new_dim_source = {}
|
||||||
|
|
Loading…
Reference in New Issue