mirror of https://github.com/hpcaitech/ColossalAI
support check_index_duplicate
parent
8754fa2553
commit
1e0fd11bc1
|
@ -179,7 +179,12 @@ class FlowTracer(object):
|
||||||
"outputs_dim": end_dim,
|
"outputs_dim": end_dim,
|
||||||
"args": {},
|
"args": {},
|
||||||
}
|
}
|
||||||
flow_flag = False
|
flow_block = False
|
||||||
|
|
||||||
|
# TODO don't allow multi outputs now
|
||||||
|
if len(outputs) > 1:
|
||||||
|
flow_block = True
|
||||||
|
return flow_block, chunk_info
|
||||||
|
|
||||||
for idx in range(start_idx, end_idx + 1):
|
for idx in range(start_idx, end_idx + 1):
|
||||||
node = self.node_list[idx]
|
node = self.node_list[idx]
|
||||||
|
@ -199,10 +204,10 @@ class FlowTracer(object):
|
||||||
self.node_list[end_idx], end_dim, node
|
self.node_list[end_idx], end_dim, node
|
||||||
)
|
)
|
||||||
if mix_flow_node_dim is None:
|
if mix_flow_node_dim is None:
|
||||||
flow_flag = True
|
flow_block = True
|
||||||
break
|
break
|
||||||
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
|
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
|
||||||
flow_flag = False
|
flow_block = False
|
||||||
for i in self._get_same_flow_node(
|
for i in self._get_same_flow_node(
|
||||||
chunk_info["inputs"], mix_flow_node
|
chunk_info["inputs"], mix_flow_node
|
||||||
):
|
):
|
||||||
|
@ -210,11 +215,15 @@ class FlowTracer(object):
|
||||||
# else, we need to chunk mix var as well
|
# else, we need to chunk mix var as well
|
||||||
else:
|
else:
|
||||||
# TODO chunk another value
|
# TODO chunk another value
|
||||||
flow_flag = True
|
flow_block = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("%s not implemented" % node.name)
|
raise NotImplementedError("%s not implemented" % node.name)
|
||||||
|
|
||||||
|
if flow_block:
|
||||||
|
flow_block = True
|
||||||
|
return flow_block, chunk_info
|
||||||
|
|
||||||
inputs_dim = []
|
inputs_dim = []
|
||||||
remove_inputs = []
|
remove_inputs = []
|
||||||
for input_node in chunk_info["inputs"]:
|
for input_node in chunk_info["inputs"]:
|
||||||
|
@ -250,7 +259,7 @@ class FlowTracer(object):
|
||||||
if i not in chunk_info["inputs"]:
|
if i not in chunk_info["inputs"]:
|
||||||
chunk_info["inputs_non_chunk"].append(i)
|
chunk_info["inputs_non_chunk"].append(i)
|
||||||
|
|
||||||
return flow_flag, chunk_info
|
return flow_block, chunk_info
|
||||||
|
|
||||||
|
|
||||||
class IndexTracer(object):
|
class IndexTracer(object):
|
||||||
|
@ -869,14 +878,6 @@ class IndexTracer(object):
|
||||||
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
# end_node_trace_source = end_node_trace['source'][end_dim]
|
|
||||||
# for node_idx, node_dim in end_node_trace_source.items():
|
|
||||||
# if node_idx < start_node_idx or node_idx > end_node_idx:
|
|
||||||
# continue
|
|
||||||
# compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
|
|
||||||
# if any(start_node_idx <= i <= end_node_idx for i in compute_list):
|
|
||||||
# return False
|
|
||||||
# return True
|
|
||||||
|
|
||||||
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||||
node_from_source = self._find_source_trace_from_node(node_from)
|
node_from_source = self._find_source_trace_from_node(node_from)
|
||||||
|
@ -1240,10 +1241,10 @@ class ChunkRegionSearch(object):
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
# detect flow meet
|
# detect flow meet
|
||||||
flow_flag, chunk_info = self.flow_tracer._detect_flow(
|
flow_block, chunk_info = self.flow_tracer._detect_flow(
|
||||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||||
)
|
)
|
||||||
if flow_flag:
|
if flow_block:
|
||||||
continue
|
continue
|
||||||
chunk_infos.append(chunk_info)
|
chunk_infos.append(chunk_info)
|
||||||
chunk_infos = self._check_duplicate_map(chunk_infos)
|
chunk_infos = self._check_duplicate_map(chunk_infos)
|
||||||
|
|
Loading…
Reference in New Issue