ColossalAI/colossalai/autochunk/reorder_graph.py

118 lines
4.7 KiB
Python
Raw Normal View History

2023-01-09 09:25:13 +00:00
from .trace_indice import TraceIndice
2023-01-06 09:13:18 +00:00
from .utils import find_idx_by_name
class ReorderGraph(object):
    """
    Reorder node list and indice trace list.

    A chunk may declare "prepose" nodes (``chunk_info["args"]["prepose_nodes"]``).
    This class moves those nodes to the start of the chunk region and keeps
    ``trace_indice.node_list``, ``trace_indice.indice_trace_list`` and the
    chunk info consistent with the new ordering.

    Attributes:
        trace_indice: object exposing ``node_list`` and ``indice_trace_list``;
            both lists are replaced on the object by ``reorder_all``.
        all_reorder_map: maps each original index to its current index,
            accumulated over every ``reorder_all`` call.
    """

    def __init__(self, trace_indice: "TraceIndice") -> None:
        self.trace_indice = trace_indice
        # Start from the identity permutation; composed on every reorder.
        self.all_reorder_map = {
            i: i for i in range(len(self.trace_indice.indice_trace_list))
        }

    @staticmethod
    def _permute(items, reorder_map):
        """Return a new list in which ``items[old]`` sits at ``reorder_map[old]``."""
        permuted = [None] * len(items)
        for old_idx, new_idx in reorder_map.items():
            permuted[new_idx] = items[old_idx]
        return permuted

    def _get_reorder_map(self, chunk_info):
        """
        Build the ``old index -> new index`` map for one chunk.

        Prepose nodes are placed, in the order they are listed, starting at
        the first index of the chunk region. Every other node inside the
        region is shifted back by the number of prepose nodes that were
        originally located after it. Nodes outside the region keep their index.

        Args:
            chunk_info: dict with ``"region"`` (start, end; end inclusive) and
                ``"args"["prepose_nodes"]`` (objects exposing ``.name``).

        Returns:
            dict mapping every index of ``trace_indice.node_list`` to its new index.
        """
        node_list = self.trace_indice.node_list
        reorder_map = {i: i for i in range(len(node_list))}

        region_start = chunk_info["region"][0]
        region_end = chunk_info["region"][1]
        prepose_idx = [
            find_idx_by_name(node.name, node_list)
            for node in chunk_info["args"]["prepose_nodes"]
        ]

        # put prepose nodes ahead
        for offset, n_idx in enumerate(prepose_idx):
            reorder_map[n_idx] = region_start + offset

        # put other nodes after prepose nodes
        # The index of a region node is its position in the list, so there is
        # no need to search for it by name again (assumes the prepose nodes are
        # members of node_list, which the name lookup above already requires).
        prepose_idx_set = set(prepose_idx)
        region_nodes = node_list[region_start : region_end + 1]
        for n_idx, _ in enumerate(region_nodes, start=region_start):
            if n_idx in prepose_idx_set:
                continue
            # number of prepose nodes that jump over this node
            shift = sum(n_idx < p_idx for p_idx in prepose_idx)
            reorder_map[n_idx] = n_idx + shift
        return reorder_map

    def _reorder_chunk_info(self, chunk_info, reorder_map):
        """
        Update ``chunk_info`` in place for the new ordering and return it.

        The region start is advanced past the prepose nodes, and the keys of
        every dict in ``chunk_info["inputs_dim"]`` are translated through
        ``reorder_map``.
        """
        # update chunk info
        region = chunk_info["region"]
        chunk_info["region"] = (
            region[0] + len(chunk_info["args"]["prepose_nodes"]),
            region[1],
        )
        chunk_info["inputs_dim"] = [
            {reorder_map[k]: v for k, v in input_dim.items()}
            for input_dim in chunk_info["inputs_dim"]
        ]
        return chunk_info

    def _update_all_reorder_map(self, reorder_map):
        """Compose ``reorder_map`` onto the accumulated ``all_reorder_map``."""
        # Only values of existing keys are reassigned, so mutating while
        # iterating is safe (the dict never changes size).
        for origin_idx, current_idx in self.all_reorder_map.items():
            self.all_reorder_map[origin_idx] = reorder_map[current_idx]

    def _reorder_self_node_list(self, reorder_map):
        """Replace ``trace_indice.node_list`` with its reordered version."""
        self.trace_indice.node_list = self._permute(
            self.trace_indice.node_list, reorder_map
        )

    def _reorder_idx_trace(self, reorder_map):
        """
        Reorder ``trace_indice.indice_trace_list`` and remap the node indices
        stored inside each trace's ``"compute"`` and ``"source"`` entries.
        """
        # reorder list
        self.trace_indice.indice_trace_list = self._permute(
            self.trace_indice.indice_trace_list, reorder_map
        )
        # update compute
        for idx_trace in self.trace_indice.indice_trace_list:
            for dim_compute in idx_trace["compute"]:
                # slice-assign so the existing list object is updated in place
                dim_compute[:] = [reorder_map[i] for i in dim_compute]
        # update source
        for idx_trace in self.trace_indice.indice_trace_list:
            source = idx_trace["source"]
            for dim_idx, dim_source in enumerate(source):
                source[dim_idx] = {reorder_map[k]: v for k, v in dim_source.items()}

    def reorder_all(self, chunk_info):
        """
        Apply the chunk's prepose reordering to all tracked state.

        Returns ``chunk_info`` unchanged when it is ``None`` or has no prepose
        nodes; otherwise returns the updated ``chunk_info``.
        """
        if chunk_info is None:
            return chunk_info
        if not chunk_info["args"]["prepose_nodes"]:
            return chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        self._update_all_reorder_map(reorder_map)
        self._reorder_idx_trace(reorder_map)
        self._reorder_self_node_list(reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return chunk_info

    def reorder_node_list(self, node_list):
        """Return ``node_list`` permuted by the accumulated ``all_reorder_map``."""
        return self._permute(node_list, self.all_reorder_map)

    def tmp_reorder(self, node_list, chunk_info):
        """
        Reorder a caller-supplied ``node_list`` for ``chunk_info`` without
        touching ``trace_indice`` or ``all_reorder_map``.

        Note that ``chunk_info`` itself is still updated in place.

        Returns:
            ``(new_node_list, chunk_info)``; the inputs are returned as-is
            when ``chunk_info`` is ``None`` or has no prepose nodes.
        """
        # Same guard as reorder_all: nothing to do without a chunk.
        if chunk_info is None:
            return node_list, chunk_info
        if not chunk_info["args"]["prepose_nodes"]:
            return node_list, chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        # new tmp node list
        new_node_list = self._permute(node_list, reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return new_node_list, chunk_info