mirror of https://github.com/hpcaitech/ColossalAI
seperate prepose_nodes
parent
f4a1607e56
commit
f856611d21
|
@ -839,36 +839,7 @@ class IndexTracer(object):
|
|||
inputs.remove(i)
|
||||
return inputs, inputs_dim
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
# get every node's chunk dim and fix dim
|
||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||
if all_node_info is None:
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
chunk_info = {
|
||||
"region": (start_idx, end_idx),
|
||||
"inputs": inputs,
|
||||
"inputs_non_chunk": [],
|
||||
"inputs_dim": inputs_dim,
|
||||
"outputs": outputs,
|
||||
"outputs_dim": end_dim,
|
||||
"node_chunk_dim": all_node_info,
|
||||
"args": {},
|
||||
}
|
||||
|
||||
# move useless nodes ahead of loop
|
||||
def _set_prepose_nodes(self, all_node_info, start_idx, end_idx):
|
||||
# get all possible prepose nodes
|
||||
maybe_prepose_nodes = []
|
||||
for node, node_info in all_node_info.items():
|
||||
|
@ -929,12 +900,45 @@ class IndexTracer(object):
|
|||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list))
|
||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||
|
||||
return prepose_nodes
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
# get every node's chunk dim and fix dim
|
||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||
if all_node_info is None:
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
chunk_info = {
|
||||
"region": (start_idx, end_idx),
|
||||
"inputs": inputs,
|
||||
"inputs_non_chunk": [],
|
||||
"inputs_dim": inputs_dim,
|
||||
"outputs": outputs,
|
||||
"outputs_dim": end_dim,
|
||||
"node_chunk_dim": all_node_info,
|
||||
"args": {},
|
||||
}
|
||||
|
||||
# move useless nodes ahead of loop
|
||||
chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx)
|
||||
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
chunk_node_list = self.node_list[start_idx : end_idx + 1]
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in prepose_nodes:
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
|
||||
for i in non_chunk_inputs:
|
||||
|
|
Loading…
Reference in New Issue