mirror of https://github.com/hpcaitech/ColossalAI
support check_index_duplicate
parent
8754fa2553
commit
1e0fd11bc1
|
@ -179,7 +179,12 @@ class FlowTracer(object):
|
|||
"outputs_dim": end_dim,
|
||||
"args": {},
|
||||
}
|
||||
flow_flag = False
|
||||
flow_block = False
|
||||
|
||||
# TODO don't allow multi outputs now
|
||||
if len(outputs) > 1:
|
||||
flow_block = True
|
||||
return flow_block, chunk_info
|
||||
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
node = self.node_list[idx]
|
||||
|
@ -199,10 +204,10 @@ class FlowTracer(object):
|
|||
self.node_list[end_idx], end_dim, node
|
||||
)
|
||||
if mix_flow_node_dim is None:
|
||||
flow_flag = True
|
||||
flow_block = True
|
||||
break
|
||||
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
|
||||
flow_flag = False
|
||||
flow_block = False
|
||||
for i in self._get_same_flow_node(
|
||||
chunk_info["inputs"], mix_flow_node
|
||||
):
|
||||
|
@ -210,11 +215,15 @@ class FlowTracer(object):
|
|||
# else, we need to chunk mix var as well
|
||||
else:
|
||||
# TODO chunk another value
|
||||
flow_flag = True
|
||||
flow_block = True
|
||||
break
|
||||
else:
|
||||
raise NotImplementedError("%s not implemented" % node.name)
|
||||
|
||||
if flow_block:
|
||||
flow_block = True
|
||||
return flow_block, chunk_info
|
||||
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
for input_node in chunk_info["inputs"]:
|
||||
|
@ -250,7 +259,7 @@ class FlowTracer(object):
|
|||
if i not in chunk_info["inputs"]:
|
||||
chunk_info["inputs_non_chunk"].append(i)
|
||||
|
||||
return flow_flag, chunk_info
|
||||
return flow_block, chunk_info
|
||||
|
||||
|
||||
class IndexTracer(object):
|
||||
|
@ -869,14 +878,6 @@ class IndexTracer(object):
|
|||
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||
return False
|
||||
return True
|
||||
# end_node_trace_source = end_node_trace['source'][end_dim]
|
||||
# for node_idx, node_dim in end_node_trace_source.items():
|
||||
# if node_idx < start_node_idx or node_idx > end_node_idx:
|
||||
# continue
|
||||
# compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
|
||||
# if any(start_node_idx <= i <= end_node_idx for i in compute_list):
|
||||
# return False
|
||||
# return True
|
||||
|
||||
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self._find_source_trace_from_node(node_from)
|
||||
|
@ -1240,10 +1241,10 @@ class ChunkRegionSearch(object):
|
|||
):
|
||||
continue
|
||||
# detect flow meet
|
||||
flow_flag, chunk_info = self.flow_tracer._detect_flow(
|
||||
flow_block, chunk_info = self.flow_tracer._detect_flow(
|
||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||
)
|
||||
if flow_flag:
|
||||
if flow_block:
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
chunk_infos = self._check_duplicate_map(chunk_infos)
|
||||
|
|
Loading…
Reference in New Issue