mirror of https://github.com/hpcaitech/ColossalAI
seperate input node dim search
parent
ae27a8b26d
commit
f4a1607e56
|
@ -812,19 +812,7 @@ class IndexTracer(object):
|
||||||
cur_node_list = next_node_list
|
cur_node_list = next_node_list
|
||||||
return all_node_info
|
return all_node_info
|
||||||
|
|
||||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
|
||||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
|
||||||
self.node_list[start_idx : end_idx + 1]
|
|
||||||
)
|
|
||||||
# only single ouput
|
|
||||||
if len(outputs) > 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# get every node's chunk dim and fix dim
|
|
||||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
|
||||||
if all_node_info is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
inputs_dim = []
|
inputs_dim = []
|
||||||
remove_inputs = []
|
remove_inputs = []
|
||||||
for input_node in inputs:
|
for input_node in inputs:
|
||||||
|
@ -841,7 +829,7 @@ class IndexTracer(object):
|
||||||
if input_node_idx in user_source:
|
if input_node_idx in user_source:
|
||||||
input_dict[user_idx] = user_source[input_node_idx]
|
input_dict[user_idx] = user_source[input_node_idx]
|
||||||
else:
|
else:
|
||||||
return None
|
return None, None
|
||||||
if len(input_dict) == 0:
|
if len(input_dict) == 0:
|
||||||
remove_inputs.append(input_node)
|
remove_inputs.append(input_node)
|
||||||
else:
|
else:
|
||||||
|
@ -849,6 +837,25 @@ class IndexTracer(object):
|
||||||
for i in remove_inputs:
|
for i in remove_inputs:
|
||||||
if i in inputs:
|
if i in inputs:
|
||||||
inputs.remove(i)
|
inputs.remove(i)
|
||||||
|
return inputs, inputs_dim
|
||||||
|
|
||||||
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||||
|
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||||
|
self.node_list[start_idx : end_idx + 1]
|
||||||
|
)
|
||||||
|
# only single ouput
|
||||||
|
if len(outputs) > 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get every node's chunk dim and fix dim
|
||||||
|
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||||
|
if all_node_info is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get input nodes' chunk dim
|
||||||
|
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||||
|
if inputs is None:
|
||||||
|
return None
|
||||||
|
|
||||||
chunk_info = {
|
chunk_info = {
|
||||||
"region": (start_idx, end_idx),
|
"region": (start_idx, end_idx),
|
||||||
|
|
Loading…
Reference in New Issue