mirror of https://github.com/hpcaitech/ColossalAI
seperate non chunk input
parent
f856611d21
commit
6685a9d022
|
@ -839,7 +839,7 @@ class IndexTracer(object):
|
||||||
inputs.remove(i)
|
inputs.remove(i)
|
||||||
return inputs, inputs_dim
|
return inputs, inputs_dim
|
||||||
|
|
||||||
def _set_prepose_nodes(self, all_node_info, start_idx, end_idx):
|
def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
|
||||||
# get all possible prepose nodes
|
# get all possible prepose nodes
|
||||||
maybe_prepose_nodes = []
|
maybe_prepose_nodes = []
|
||||||
for node, node_info in all_node_info.items():
|
for node, node_info in all_node_info.items():
|
||||||
|
@ -903,6 +903,18 @@ class IndexTracer(object):
|
||||||
|
|
||||||
return prepose_nodes
|
return prepose_nodes
|
||||||
|
|
||||||
|
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||||
|
# we need to log input nodes to avoid deleteing them in the loop
|
||||||
|
chunk_node_list = self.node_list[start_idx : end_idx + 1]
|
||||||
|
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||||
|
for n in chunk_info["args"]["prepose_nodes"]:
|
||||||
|
chunk_node_list.remove(n)
|
||||||
|
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
|
||||||
|
for i in non_chunk_inputs:
|
||||||
|
if i not in chunk_info["inputs"]:
|
||||||
|
chunk_info["inputs_non_chunk"].append(i)
|
||||||
|
return chunk_info
|
||||||
|
|
||||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
|
@ -917,7 +929,9 @@ class IndexTracer(object):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# get input nodes' chunk dim
|
# get input nodes' chunk dim
|
||||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
inputs, inputs_dim = self._get_input_nodes_dim(
|
||||||
|
inputs, start_idx, end_idx, all_node_info
|
||||||
|
)
|
||||||
if inputs is None:
|
if inputs is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -933,17 +947,12 @@ class IndexTracer(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
# move useless nodes ahead of loop
|
# move useless nodes ahead of loop
|
||||||
chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx)
|
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
|
||||||
|
all_node_info, start_idx, end_idx
|
||||||
|
)
|
||||||
|
|
||||||
# we need to log input nodes to avoid deleteing them in the loop
|
# find non chunk inputs
|
||||||
chunk_node_list = self.node_list[start_idx : end_idx + 1]
|
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
|
||||||
for n in chunk_info["args"]["prepose_nodes"]:
|
|
||||||
chunk_node_list.remove(n)
|
|
||||||
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
|
|
||||||
for i in non_chunk_inputs:
|
|
||||||
if i not in chunk_info["inputs"]:
|
|
||||||
chunk_info["inputs_non_chunk"].append(i)
|
|
||||||
|
|
||||||
# reassgin reshape size, some size may have changed due to chunk
|
# reassgin reshape size, some size may have changed due to chunk
|
||||||
chunk_info = self._reassgin_reshape_size(chunk_info)
|
chunk_info = self._reassgin_reshape_size(chunk_info)
|
||||||
|
|
Loading…
Reference in New Issue