mirror of https://github.com/hpcaitech/ColossalAI
seperate flow tracer
parent
fd87d78a28
commit
ae27a8b26d
|
@ -745,14 +745,7 @@ class IndexTracer(object):
|
|||
next_node_list.append(arg_node)
|
||||
return True
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [self.node_list[end_idx]] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
|
@ -763,7 +756,6 @@ class IndexTracer(object):
|
|||
# get cur node info
|
||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
cur_node_idx = find_idx_by_name(cur_node.name, self.node_list)
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = self._find_compute_trace_from_node(cur_node)
|
||||
cur_node_source = self._find_source_trace_from_node(cur_node)
|
||||
|
@ -818,6 +810,20 @@ class IndexTracer(object):
|
|||
else:
|
||||
raise NotImplementedError()
|
||||
cur_node_list = next_node_list
|
||||
return all_node_info
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
# get every node's chunk dim and fix dim
|
||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||
if all_node_info is None:
|
||||
return None
|
||||
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
|
|
Loading…
Reference in New Issue