from .trace_indice import TraceIndice
from .utils import (
    find_chunk_all_input_nodes,
    find_chunk_compute_input_and_output_nodes,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)

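# TraceFlow works on the index traces recorded by TraceIndice: it validates a
# candidate chunk region and collects each node's chunk dim / fix dim into the
# chunk_info dict returned by flow_search below.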
class TraceFlow(object):
    def __init__(self, trace_indice: TraceIndice) -> None:
        self.trace_indice = trace_indice

    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
        """
        Check two given indices: one should be the source of the other.

        Args:
            start_dim (int): chunk dim of the start node
            start_node (node): start node
            start_idx (int): index of the start node in the node list
            end_dim (int): chunk dim of the end node
            end_node (node): end node

        Returns:
            bool: True if the check passes
        """
        start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_trace_source = end_node_trace["source"][end_dim]
        sorted_source = sorted(
            end_node_trace_source.items(), key=lambda d: d[0], reverse=True
        )
        for node_idx, node_dim in sorted_source:
            if node_idx == start_node_idx and start_dim in node_dim:
                return True
            # we have met a node outside the chunk region that is not an input node
            if node_idx < start_idx:
                return False
        return False

    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
        """
        Check that the given dim of the end node has not been computed with any
        node inside the region [start_idx, end_idx].

        Args:
            start_idx (int): index of the start node in the node list
            end_dim (int): chunk dim of the end node
            end_node (node): end node
            end_idx (int): index of the end node in the node list

        Returns:
            bool: True if the check passes
        """
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_compute = end_node_trace["compute"][end_dim]
        if any(start_idx <= i <= end_idx for i in end_node_compute):
            return False
        return True

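    # Look up, in node_from's source trace for dim node_from_dim, which dim of
    # node_to it originates from; returns None if node_to is not a source.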
    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
        node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
        dim_source = node_from_source[node_from_dim]
        node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
        for k, v in dim_source.items():
            if k == node_to_idx:
                return v
        return None

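    # Find the dim of `node` that is inherited from `input_dim[0]` of
    # `input_node`, according to the source trace; returns None if absent.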
    def _find_inherit_dim(self, input_node, input_dim, node):
        input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
        node_trace_source = self.trace_indice._find_source_trace_from_node(node)
        for node_dim in range(len(get_node_shape(node))):
            if (
                input_node_idx in node_trace_source[node_dim]
                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
            ):
                return node_dim
        return None

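    # Return False if, for some node in the region, more than one of its dims
    # traces back to a chunked input dim (a duplicated index); with
    # return_dim=True the offending dims are returned as well.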
    def check_index_duplicate(self, chunk_infos, return_dim=False):
        input_dim_after_node = {}
        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
                inherit_dim = self._find_inherit_dim(
                    input_node, v, self.trace_indice.node_list[k]
                )
                if inherit_dim:
                    input_dim_after_node[k] = inherit_dim

        for node in self.trace_indice.node_list[
            chunk_infos["region"][0] : chunk_infos["region"][1] + 1
        ]:
            if is_non_compute_node_except_placeholder(node):
                continue
            count = 0
            duplicate_dims = []
            node_trace_source = self.trace_indice._find_source_trace_from_node(node)
            for node_dim in range(len(get_node_shape(node))):
                duplicate_dim = []
                duplicate_flag = False
                dim_source = node_trace_source[node_dim]
                for k, v in dim_source.items():
                    if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
                        if k in input_dim_after_node and input_dim_after_node[k] in v:
                            duplicate_flag = True
                            duplicate_dim.append((k, v))
                duplicate_dims.append(duplicate_dim)
                if duplicate_flag:
                    count += 1

            if count > 1:
                if return_dim:
                    return False, duplicate_dims
                else:
                    return False
        if return_dim:
            return True, None
        else:
            return True

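    # Propagate the chunk flow from the current node to one of its args: decide
    # the arg's chunk dim and fix dims, record them in all_node_info, and return
    # False when the flow is illegal (e.g. the chunk dim was computed with the arg).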
    def _assgin_single_node_flow(
        self,
        arg_node,
        start_idx,
        end_idx,
        cur_node_dim,
        cur_node_compute,
        cur_node_source,
        cur_node_fix_dim,
        all_node_info,
        next_node_list,
    ):
        arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
        # args outside the chunk region are inputs and need no flow assignment
        if not (start_idx <= arg_idx < end_idx):
            return True

        # find arg dim
        if cur_node_dim is not None:
            # the chunk dim has already been computed with this arg, so chunking is illegal
            if arg_idx in cur_node_compute[cur_node_dim]:
                return False
            if arg_idx not in cur_node_source[cur_node_dim]:
                arg_dim = None
            else:
                arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
        else:
            arg_dim = None

        # get fix dim
        arg_fix_dim = []
        if cur_node_dim is not None:
            for i in cur_node_fix_dim:
                fix_dim_source = cur_node_source[i]
                if arg_idx in fix_dim_source:
                    arg_fix_dim.append(fix_dim_source[arg_idx][0])

        # if the arg is already recorded, its chunk dim must match
        if arg_node in all_node_info:
            if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                return False
            all_node_info[arg_node]["fix_dim"] = list(
                set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
            )
        # otherwise record it
        else:
            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}

        next_node_list.append(arg_node)
        return True

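    # Walk backwards from the region's last node and assign a chunk dim and
    # fix dims to every node in the region; returns None if any assignment
    # conflicts.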
    def _get_all_node_info(self, end_dim, start_idx, end_idx):
        cur_node_list = [
            self.trace_indice.node_list[end_idx]
        ]  # start from the last node
        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

        while len(cur_node_list) > 0:
            next_node_list = []

            for cur_node in cur_node_list:
                # get cur node info
                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
                if cur_node_chunk_dim:
                    cur_node_compute = self.trace_indice._find_compute_trace_from_node(
                        cur_node
                    )
                    cur_node_source = self.trace_indice._find_source_trace_from_node(
                        cur_node
                    )
                else:
                    cur_node_compute = cur_node_source = None

                # get all valid args
                arg_list = []
                for arg in cur_node.args:
                    if type(arg) != type(cur_node):
                        continue
                    if is_non_compute_node(arg):
                        continue
                    arg_list.append(arg)
                    flow_flag = self._assgin_single_node_flow(
                        arg,
                        start_idx,
                        end_idx,
                        cur_node_chunk_dim,
                        cur_node_compute,
                        cur_node_source,
                        cur_node_fix_dim,
                        all_node_info,
                        next_node_list,
                    )
                    if flow_flag == False:
                        return None

                if len(arg_list) == 2:
                    if any(i in cur_node.name for i in ["add", "mul"]):
                        for arg in arg_list:
                            if not (
                                start_idx
                                <= find_idx_by_name(
                                    arg.name, self.trace_indice.node_list
                                )
                                < end_idx
                            ):
                                continue
                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
                            arg_fix_dim = all_node_info[arg]["fix_dim"]
                            arg_shape = get_node_shape(arg)
                            # every dim except size-1 dims and the chunk dim becomes a fix dim
                            for i, shape in enumerate(arg_shape):
                                if shape != 1 and i != cur_node_chunk_dim:
                                    if i == arg_chunk_dim:
                                        return None
                                    if i not in arg_fix_dim:
                                        arg_fix_dim.append(i)
                    elif "einsum" in cur_node.name:
                        pass
                    elif "matmul" in cur_node.name:
                        pass
                    else:
                        raise NotImplementedError()
            cur_node_list = next_node_list
        return all_node_info

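    # For every input node, record which dim each in-region user chunks it
    # along; inputs with no chunked users are dropped from `inputs`.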
    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
        inputs_dim = []
        remove_inputs = []
        for input_node in inputs:
            input_dict = {}
            input_node_idx = find_idx_by_name(
                input_node.name, self.trace_indice.node_list
            )
            for user in input_node.users.keys():
                if is_non_compute_node(user):
                    continue
                user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                if start_idx <= user_idx <= end_idx:
                    chunk_dim = all_node_info[user]["chunk_dim"]
                    if chunk_dim is not None:
                        user_source = self.trace_indice._find_source_trace_from_node(
                            user
                        )[chunk_dim]
                        if input_node_idx in user_source:
                            input_dict[user_idx] = user_source[input_node_idx]
                        else:
                            return None, None
            if len(input_dict) == 0:
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        for i in remove_inputs:
            if i in inputs:
                inputs.remove(i)
        return inputs, inputs_dim

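    # Collect nodes whose chunk_dim is None and whose in-region args are also
    # independent of the chunk dim; these can be hoisted ("preposed") ahead of
    # the chunked loop.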
    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
        # get all possible prepose nodes
        maybe_prepose_nodes = []
        for node, node_info in all_node_info.items():
            if node_info["chunk_dim"] is None:
                maybe_prepose_nodes.append(node)
        maybe_prepose_nodes.sort(
            key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
            reverse=True,
        )  # from last node to first node
        prepose_nodes = []
        # take each candidate as a root and search its args; if they are all legal,
        # the root and its args become prepose nodes
        while len(maybe_prepose_nodes) > 0:
            tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
            tmp_cur_related_prepose_nodes = []
            prepose_flag = True

            # walk the current nodes' args until we leave the chunk region
            while len(tmp_cur_prepose_nodes) > 0:
                if prepose_flag == False:
                    break
                tmp_next_prepose_nodes = []
                tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
                for cur_prepose_node in tmp_cur_prepose_nodes:
                    if prepose_flag == False:
                        break
                    for cur_prepose_node_arg in cur_prepose_node.args:
                        if type(cur_prepose_node_arg) != type(cur_prepose_node):
                            continue
                        # out of the chunk region
                        if not (
                            start_idx
                            <= find_idx_by_name(
                                cur_prepose_node_arg.name, self.trace_indice.node_list
                            )
                            < end_idx
                        ):
                            continue
                        # compute op inside the region
                        elif cur_prepose_node_arg in all_node_info:
                            if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
                                tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                            else:
                                prepose_flag = False
                                break
                        # non compute op
                        else:
                            tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                tmp_cur_prepose_nodes = tmp_next_prepose_nodes

            if prepose_flag == False:
                maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
                continue
            else:
                for n in tmp_cur_related_prepose_nodes:
                    if n not in prepose_nodes:
                        prepose_nodes.append(n)
                    if n in maybe_prepose_nodes:
                        maybe_prepose_nodes.remove(n)
        # sort by index
        prepose_nodes.sort(
            key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
        )

        return prepose_nodes

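    # After prepose nodes are excluded, any remaining external input of the
    # region that is not already a chunked input is recorded in inputs_non_chunk.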
    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
        # we need to log input nodes to avoid deleting them in the loop
        chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
        # prepose nodes are excluded so that their args are not counted as non_chunk_inputs
        for n in chunk_info["args"]["prepose_nodes"]:
            chunk_node_list.remove(n)
        non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
        for i in non_chunk_inputs:
            if i not in chunk_info["inputs"]:
                chunk_info["inputs_non_chunk"].append(i)
        return chunk_info

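    # Entry point: given a candidate region [start_idx, end_idx] and the output
    # chunk dim end_dim, build the chunk_info dict describing how to chunk the
    # region, or return None if the region cannot be chunked.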
    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
        inputs, outputs = find_chunk_compute_input_and_output_nodes(
            self.trace_indice.node_list[start_idx : end_idx + 1]
        )
        # only a single output is supported
        if len(outputs) > 1:
            return None

        # get every node's chunk dim and fix dim
        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
        if all_node_info is None:
            return None

        # get input nodes' chunk dim
        inputs, inputs_dim = self._get_input_nodes_dim(
            inputs, start_idx, end_idx, all_node_info
        )
        if inputs is None:
            return None

        chunk_info = {
            "region": (start_idx, end_idx),
            "inputs": inputs,
            "inputs_non_chunk": [],
            "inputs_dim": inputs_dim,
            "outputs": outputs,
            "outputs_dim": end_dim,
            "node_chunk_dim": all_node_info,
            "args": {},
        }

        # move nodes without a chunk dim ahead of the loop
        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
            all_node_info, start_idx, end_idx
        )

        # find non chunk inputs
        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)

        # reassign reshape sizes, since some sizes change due to chunking
        chunk_info = self._reassgin_reshape_size(chunk_info)

        return chunk_info

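    # For reshape/view nodes inside the region, replace the size argument on the
    # chunk dim with a runtime expression of chunk_size/chunk_idx, since the
    # chunked tensor is smaller than the original along that dim.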
    def _reassgin_reshape_size(self, chunk_info):
        chunk_region = chunk_info["region"]
        reshape_size = {}
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[
            chunk_info["outputs_dim"]
        ]
        for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
            if any(i in node.name for i in ["reshape", "view"]):
                reshape_args = node.args[1:]
                reshape_log = self.trace_indice.indice_view_list[node]
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                reshape_size[node.name] = {}
                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                    if reshape_arg_dim in reshape_log["dim_to"]:
                        continue
                    if reshape_arg_dim == chunk_dim:
                        reshape_size[node.name][reshape_arg.name] = (
                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
                        )
        chunk_info["reshape_size"] = reshape_size
        return chunk_info
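

# A minimal usage sketch (illustrative only; how trace_indice and the region
# bounds are obtained lies outside this file and is assumed here):
#
#   flow = TraceFlow(trace_indice)  # trace_indice: a fully traced TraceIndice
#   chunk_info = flow.flow_search(start_idx, start_dim, end_idx, end_dim)
#   if chunk_info is not None:
#       # chunk_info carries "region", "inputs", "inputs_dim",
#       # "inputs_non_chunk", "outputs", "outputs_dim", "node_chunk_dim",
#       # "args"["prepose_nodes"] and "reshape_size" for later code generation.
#       ...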