mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
4.6 KiB
111 lines
4.6 KiB
from .trace_indice import TraceIndice |
|
from .utils import NodeMgr |
|
|
|
|
|
class ReorderGraph(object):
    """
    Reorder node list and indice trace list.

    When a chunk region lists "prepose" nodes (``chunk_info["args"]["prepose_nodes"]``),
    those nodes are moved to the front of the region, in the order they are listed,
    and the other region nodes are shifted back to make room. Every index-keyed
    structure (node list, indice trace list, chunk info) is then remapped.

    Attributes:
        trace_indice: object holding ``indice_trace_list``; presumably one trace
            dict per node, indexed like the node list -- TODO confirm.
        node_mgr: node-list manager whose list is rewritten by ``reorder_all``.
        all_reorder_map: maps each node index at construction time to that node's
            current index after every ``reorder_all`` applied so far.
    """

    def __init__(self, trace_indice: "TraceIndice", node_mgr: "NodeMgr") -> None:
        self.trace_indice = trace_indice
        self.node_mgr = node_mgr
        # identity map: nothing has been reordered yet
        self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}

    @staticmethod
    def _apply_reorder_map(items, reorder_map):
        """Return a new list in which ``items[old_idx]`` sits at ``reorder_map[old_idx]``."""
        reordered = [None] * len(items)
        for old_idx, new_idx in reorder_map.items():
            reordered[new_idx] = items[old_idx]
        return reordered

    def _get_reorder_map(self, chunk_info):
        """
        Build the old-index -> new-index map that moves prepose nodes to the region start.

        Args:
            chunk_info: dict with ``"region"`` as ``(start, end)`` (end inclusive, as
                the slice below uses ``end + 1``) and ``["args"]["prepose_nodes"]``.

        Returns:
            dict over every node index of ``node_mgr``'s current list. Prepose nodes go
            to ``start, start + 1, ...`` in listed order. Every other node in the region
            shifts back by the number of prepose nodes that came after it. Non-prepose
            nodes outside the region keep their index. Assumes the prepose nodes lie
            inside the region -- TODO confirm.
        """
        reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}

        chunk_region_start = chunk_info["region"][0]
        chunk_region_end = chunk_info["region"][1]
        chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(n) for n in chunk_info["args"]["prepose_nodes"]]
        prepose_idx_set = set(chunk_prepose_nodes_idx)

        # put prepose nodes ahead, keeping the order in which they are listed
        for offset, n_idx in enumerate(chunk_prepose_nodes_idx):
            reorder_map[n_idx] = chunk_region_start + offset

        # put other nodes after prepose nodes: each one is pushed back by the
        # number of prepose nodes that originally followed it
        for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
            n_idx = self.node_mgr.find_node_idx(n)
            if n_idx in prepose_idx_set:
                continue
            reorder_map[n_idx] = n_idx + sum(n_idx < i for i in chunk_prepose_nodes_idx)

        return reorder_map

    def _reorder_chunk_info(self, chunk_info, reorder_map):
        """
        Update ``chunk_info`` in place to match ``reorder_map``, and return it.

        ``"region"`` and ``"inputs_dim"`` are reassigned on the passed dict itself.
        """
        # the prepose nodes now occupy the first slots of the region, so the
        # chunk starts right after them; the region end is left unchanged
        chunk_info["region"] = (
            chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
            chunk_info["region"][1],
        )
        # remap the node-index keys of every inputs_dim dict
        chunk_info["inputs_dim"] = [
            {reorder_map[k]: v for k, v in input_dim.items()} for input_dim in chunk_info["inputs_dim"]
        ]
        return chunk_info

    def _update_all_reorder_map(self, reorder_map):
        """Compose ``reorder_map`` onto ``all_reorder_map``, updating the dict in place."""
        # only values are reassigned (no keys added/removed), so iterating is safe
        for origin_idx, map_idx in self.all_reorder_map.items():
            self.all_reorder_map[origin_idx] = reorder_map[map_idx]

    def _reorder_self_node_list(self, reorder_map):
        """Permute ``node_mgr``'s node list by ``reorder_map`` and store it back."""
        new_node_list = [None] * len(self.node_mgr.get_node_list())
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
        self.node_mgr.update_node_list(new_node_list)

    def _reorder_idx_trace(self, reorder_map):
        """
        Permute ``indice_trace_list`` by ``reorder_map`` and remap the node indices
        stored in each trace's ``"compute"`` lists (in place) and ``"source"`` dict keys.
        """
        # reorder list
        self.trace_indice.indice_trace_list = self._apply_reorder_map(
            self.trace_indice.indice_trace_list, reorder_map
        )
        # update compute: the inner lists are mutated in place
        for idx_trace in self.trace_indice.indice_trace_list:
            for dim_compute in idx_trace["compute"]:
                for pos, node_idx in enumerate(dim_compute):
                    dim_compute[pos] = reorder_map[node_idx]
        # update source: each per-dim dict is replaced with a re-keyed copy
        for idx_trace in self.trace_indice.indice_trace_list:
            source = idx_trace["source"]
            for dim_idx, dim_source in enumerate(source):
                source[dim_idx] = {reorder_map[k]: v for k, v in dim_source.items()}

    def reorder_all(self, chunk_info):
        """
        Apply the prepose reorder for ``chunk_info`` to all held state.

        Updates ``all_reorder_map``, the indice trace, and ``node_mgr``'s node list,
        then returns ``chunk_info`` updated in place. Returns ``chunk_info`` unchanged
        when it is ``None`` or has no prepose nodes.
        """
        if chunk_info is None:
            return chunk_info
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        self._update_all_reorder_map(reorder_map)
        self._reorder_idx_trace(reorder_map)
        self._reorder_self_node_list(reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return chunk_info

    def reorder_node_list(self, node_list):
        """
        Return a new list with ``node_list`` permuted by the accumulated ``all_reorder_map``.

        ``node_list`` is presumably in the original (construction-time) order -- TODO confirm.
        """
        return self._apply_reorder_map(node_list, self.all_reorder_map)

    def tmp_reorder(self, node_list, chunk_info):
        """
        Preview the reorder for ``chunk_info`` without touching ``node_mgr`` or the trace.

        Args:
            node_list: list to permute. The reorder map is computed from ``node_mgr``'s
                current list, so this is assumed to be indexed the same way -- TODO confirm.
            chunk_info: chunk description. NOTE: it is still updated in place (its
                ``"region"`` and ``"inputs_dim"`` entries are reassigned); pass a copy
                if the original must be kept.

        Returns:
            ``(new_node_list, chunk_info)``. Both are returned as-is when ``chunk_info``
            is ``None`` or has no prepose nodes.
        """
        if chunk_info is None or len(chunk_info["args"]["prepose_nodes"]) == 0:
            return node_list, chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        # new tmp node list
        new_node_list = self._apply_reorder_map(node_list, reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return new_node_list, chunk_info
|
|
|