mirror of https://github.com/hpcaitech/ColossalAI
reorder nodes
parent
e0ae68e736
commit
884a228ea6
127
chunk_codegen.py
127
chunk_codegen.py
|
@ -71,6 +71,7 @@ class IndexTracer(object):
|
|||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
self.idx_count = -1
|
||||
self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))}
|
||||
|
||||
def _init_idx_trace_list(self):
|
||||
idx_trace_list = []
|
||||
|
@ -973,6 +974,91 @@ class IndexTracer(object):
|
|||
|
||||
return chunk_info
|
||||
|
||||
def _get_reorder_map(self, chunk_info):
|
||||
reorder_map = {i: i for i in range(len(self.node_list))}
|
||||
|
||||
chunk_region_start = chunk_info["region"][0]
|
||||
chunk_region_end = chunk_info["region"][1]
|
||||
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
||||
chunk_prepose_nodes_idx = [
|
||||
_find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes
|
||||
]
|
||||
# put prepose nodes ahead
|
||||
for idx, n in enumerate(chunk_prepose_nodes):
|
||||
n_idx = chunk_prepose_nodes_idx[idx]
|
||||
reorder_map[n_idx] = chunk_region_start + idx
|
||||
# put other nodes after prepose nodes
|
||||
for n in self.node_list[chunk_region_start : chunk_region_end + 1]:
|
||||
if n in chunk_prepose_nodes:
|
||||
continue
|
||||
n_idx = _find_idx_by_name(n.name, self.node_list)
|
||||
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
||||
reorder_map[n_idx] = n_idx + pos
|
||||
|
||||
return reorder_map
|
||||
|
||||
def _reorder_chunk_info(self, chunk_info, reorder_map):
|
||||
# update chunk info
|
||||
chunk_info["region"] = (
|
||||
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
|
||||
chunk_info["region"][1],
|
||||
)
|
||||
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
|
||||
new_input_dim = {}
|
||||
for k, v in input_dim.items():
|
||||
new_input_dim[reorder_map[k]] = v
|
||||
chunk_info["inputs_dim"][idx] = new_input_dim
|
||||
return chunk_info
|
||||
|
||||
def _update_all_reorder_map(self, reorder_map):
|
||||
for origin_idx, map_idx in self.all_reorder_map.items():
|
||||
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
||||
|
||||
def _reorder_self_node_list(self, reorder_map):
|
||||
new_node_list = [None for _ in range(len(self.node_list))]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_node_list[new_idx] = self.node_list[old_idx]
|
||||
self.node_list = new_node_list
|
||||
|
||||
def _reorder_idx_trace(self, reorder_map):
|
||||
# reorder list
|
||||
new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx]
|
||||
self.idx_trace_list = new_idx_trace_list
|
||||
# update compute
|
||||
for idx_trace in self.idx_trace_list:
|
||||
compute = idx_trace["compute"]
|
||||
for dim_compute in compute:
|
||||
for idx, i in enumerate(dim_compute):
|
||||
dim_compute[idx] = reorder_map[i]
|
||||
# update source
|
||||
for idx_trace in self.idx_trace_list:
|
||||
source = idx_trace["source"]
|
||||
for dim_idx, dim_source in enumerate(source):
|
||||
new_dim_source = {}
|
||||
for k, v in dim_source.items():
|
||||
new_dim_source[reorder_map[k]] = v
|
||||
source[dim_idx] = new_dim_source
|
||||
|
||||
def reorder_all(self, chunk_info):
|
||||
if chunk_info is None:
|
||||
return chunk_info
|
||||
if len(chunk_info["args"]["prepose_nodes"]) == 0:
|
||||
return chunk_info
|
||||
reorder_map = self._get_reorder_map(chunk_info)
|
||||
self._update_all_reorder_map(reorder_map)
|
||||
self._reorder_idx_trace(reorder_map)
|
||||
self._reorder_self_node_list(reorder_map)
|
||||
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
|
||||
return chunk_info
|
||||
|
||||
def reorder_node_list(self, node_list):
|
||||
new_node_list = [None for _ in range(len(node_list))]
|
||||
for old_idx, new_idx in self.all_reorder_map.items():
|
||||
new_node_list[new_idx] = node_list[old_idx]
|
||||
return new_node_list
|
||||
|
||||
|
||||
class MemoryEstimator(object):
|
||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
||||
|
@ -1476,6 +1562,7 @@ class ChunkRegionSearch(object):
|
|||
best_chunk_region = self._search_best_chunk_region(
|
||||
possible_chunk_regions, chunk_regions
|
||||
)
|
||||
best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
||||
def _stop_search(self, init_mem_peak, mem_peak):
|
||||
|
@ -1670,8 +1757,7 @@ def emit_code_with_chunk(
|
|||
chunk_outputs = [i["outputs"][0] for i in chunk_search]
|
||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
|
||||
|
||||
chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search]
|
||||
|
||||
node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
|
||||
node_idx = 0
|
||||
region_idx = 0
|
||||
within_chunk_region = False
|
||||
|
@ -1682,12 +1768,6 @@ def emit_code_with_chunk(
|
|||
if node_idx in chunk_starts:
|
||||
within_chunk_region = True
|
||||
region_idx = chunk_starts.index(node_idx)
|
||||
# add prepose nodes
|
||||
for i in chunk_prepose_nodes[region_idx]:
|
||||
prepose_node = node_list[_find_idx_by_name(i.name, node_list)]
|
||||
emit_node_func(prepose_node, body)
|
||||
delete_unused_value_func(prepose_node, body, chunk_inputs_names)
|
||||
# add for loop
|
||||
body.append(
|
||||
_gen_loop_start(
|
||||
chunk_inputs[region_idx],
|
||||
|
@ -1697,24 +1777,19 @@ def emit_code_with_chunk(
|
|||
)
|
||||
|
||||
if within_chunk_region:
|
||||
if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]):
|
||||
pass
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][
|
||||
input_node_idx
|
||||
].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim, "chunk_idx", _get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
body[-1] = " " + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim, "chunk_idx", _get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
body[-1] = " " + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
if node_idx not in chunk_inputs:
|
||||
|
|
Loading…
Reference in New Issue