mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
774d34f1aa
commit
522f017418
|
@ -1004,7 +1004,7 @@ class FlowTracer(object):
|
|||
|
||||
# if already in node_info, arg dim must be same
|
||||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node]['chunk_dim'] != arg_dim:
|
||||
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]["fix_dim"] = list(
|
||||
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
|
||||
|
@ -1132,16 +1132,19 @@ class FlowTracer(object):
|
|||
# get all possible prepose nodes
|
||||
maybe_prepose_nodes = []
|
||||
for node, node_info in all_node_info.items():
|
||||
if node_info['chunk_dim'] is None:
|
||||
if node_info["chunk_dim"] is None:
|
||||
maybe_prepose_nodes.append(node)
|
||||
maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
# treat each node as a root and search its args; if all are legal, mark the root and its args as prepose nodes
|
||||
while len(maybe_prepose_nodes) > 0:
|
||||
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
|
||||
tmp_cur_related_prepose_nodes = []
|
||||
prepose_flag = True
|
||||
|
||||
|
||||
# loop cur node's all arg until out of chunk
|
||||
while len(tmp_cur_prepose_nodes) > 0:
|
||||
tmp_next_prepose_nodes = []
|
||||
|
@ -1151,20 +1154,28 @@ class FlowTracer(object):
|
|||
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||
continue
|
||||
# out of loop
|
||||
if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx):
|
||||
if not (
|
||||
start_idx
|
||||
<= _find_idx_by_name(
|
||||
cur_prepose_node_arg.name, self.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
continue
|
||||
# compute op in loop
|
||||
elif cur_prepose_node_arg in all_node_info:
|
||||
if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None:
|
||||
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
|
||||
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||
else:
|
||||
prepose_flag = False
|
||||
break; break; break
|
||||
break
|
||||
break
|
||||
break
|
||||
# non compute op
|
||||
else:
|
||||
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
|
||||
|
||||
|
||||
if prepose_flag == False:
|
||||
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
|
||||
continue
|
||||
|
@ -1175,21 +1186,21 @@ class FlowTracer(object):
|
|||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list))
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)
|
||||
)
|
||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||
|
||||
|
||||
# we need to log input nodes to avoid deleting them in the loop
|
||||
chunk_node_list = self.node_list[start_idx : end_idx + 1]
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in prepose_nodes:
|
||||
chunk_node_list.remove(n)
|
||||
non_chunk_inputs = _find_chunk_all_input_nodes(
|
||||
chunk_node_list
|
||||
)
|
||||
non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
|
||||
for i in non_chunk_inputs:
|
||||
if i not in chunk_info["inputs"] and i not in prepose_nodes:
|
||||
chunk_info["inputs_non_chunk"].append(i)
|
||||
|
||||
|
||||
return chunk_info
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue