mirror of https://github.com/hpcaitech/ColossalAI
seperate flow tracer
parent
fd87d78a28
commit
ae27a8b26d
|
@ -745,14 +745,7 @@ class IndexTracer(object):
|
||||||
next_node_list.append(arg_node)
|
next_node_list.append(arg_node)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
|
||||||
self.node_list[start_idx : end_idx + 1]
|
|
||||||
)
|
|
||||||
# only single ouput
|
|
||||||
if len(outputs) > 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
cur_node_list = [self.node_list[end_idx]] # start from the last node
|
cur_node_list = [self.node_list[end_idx]] # start from the last node
|
||||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||||
|
|
||||||
|
@ -763,7 +756,6 @@ class IndexTracer(object):
|
||||||
# get cur node info
|
# get cur node info
|
||||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||||
cur_node_idx = find_idx_by_name(cur_node.name, self.node_list)
|
|
||||||
if cur_node_chunk_dim:
|
if cur_node_chunk_dim:
|
||||||
cur_node_compute = self._find_compute_trace_from_node(cur_node)
|
cur_node_compute = self._find_compute_trace_from_node(cur_node)
|
||||||
cur_node_source = self._find_source_trace_from_node(cur_node)
|
cur_node_source = self._find_source_trace_from_node(cur_node)
|
||||||
|
@ -818,6 +810,20 @@ class IndexTracer(object):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
cur_node_list = next_node_list
|
cur_node_list = next_node_list
|
||||||
|
return all_node_info
|
||||||
|
|
||||||
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||||
|
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||||
|
self.node_list[start_idx : end_idx + 1]
|
||||||
|
)
|
||||||
|
# only single ouput
|
||||||
|
if len(outputs) > 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get every node's chunk dim and fix dim
|
||||||
|
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||||
|
if all_node_info is None:
|
||||||
|
return None
|
||||||
|
|
||||||
inputs_dim = []
|
inputs_dim = []
|
||||||
remove_inputs = []
|
remove_inputs = []
|
||||||
|
|
Loading…
Reference in New Issue