mirror of https://github.com/hpcaitech/ColossalAI
rename trace_index to trace_indice
parent
065f0b4c27
commit
0ea903b94e
|
@ -94,9 +94,9 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body):
|
def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_idx, node, body):
|
||||||
if "ones_like" in node.name:
|
if "ones_like" in node.name:
|
||||||
meta_node = search_chunk.trace_index.node_list[node_idx]
|
meta_node = search_chunk.trace_indice.node_list[node_idx]
|
||||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||||
source_node = meta_node.args[0].args[0]
|
source_node = meta_node.args[0].args[0]
|
||||||
|
|
|
@ -1,22 +1,22 @@
|
||||||
from .trace_index import TraceIndex
|
from .trace_indice import TraceIndice
|
||||||
from .utils import find_idx_by_name
|
from .utils import find_idx_by_name
|
||||||
|
|
||||||
|
|
||||||
class ReorderGraph(object):
|
class ReorderGraph(object):
|
||||||
def __init__(self, trace_index: TraceIndex) -> None:
|
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||||
self.trace_index = trace_index
|
self.trace_indice = trace_indice
|
||||||
self.all_reorder_map = {
|
self.all_reorder_map = {
|
||||||
i: i for i in range(len(self.trace_index.idx_trace_list))
|
i: i for i in range(len(self.trace_indice.idx_trace_list))
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_reorder_map(self, chunk_info):
|
def _get_reorder_map(self, chunk_info):
|
||||||
reorder_map = {i: i for i in range(len(self.trace_index.node_list))}
|
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
|
||||||
|
|
||||||
chunk_region_start = chunk_info["region"][0]
|
chunk_region_start = chunk_info["region"][0]
|
||||||
chunk_region_end = chunk_info["region"][1]
|
chunk_region_end = chunk_info["region"][1]
|
||||||
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
||||||
chunk_prepose_nodes_idx = [
|
chunk_prepose_nodes_idx = [
|
||||||
find_idx_by_name(i.name, self.trace_index.node_list)
|
find_idx_by_name(i.name, self.trace_indice.node_list)
|
||||||
for i in chunk_prepose_nodes
|
for i in chunk_prepose_nodes
|
||||||
]
|
]
|
||||||
# put prepose nodes ahead
|
# put prepose nodes ahead
|
||||||
|
@ -24,10 +24,10 @@ class ReorderGraph(object):
|
||||||
n_idx = chunk_prepose_nodes_idx[idx]
|
n_idx = chunk_prepose_nodes_idx[idx]
|
||||||
reorder_map[n_idx] = chunk_region_start + idx
|
reorder_map[n_idx] = chunk_region_start + idx
|
||||||
# put other nodes after prepose nodes
|
# put other nodes after prepose nodes
|
||||||
for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]:
|
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
|
||||||
if n in chunk_prepose_nodes:
|
if n in chunk_prepose_nodes:
|
||||||
continue
|
continue
|
||||||
n_idx = find_idx_by_name(n.name, self.trace_index.node_list)
|
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
|
||||||
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
||||||
reorder_map[n_idx] = n_idx + pos
|
reorder_map[n_idx] = n_idx + pos
|
||||||
|
|
||||||
|
@ -53,25 +53,25 @@ class ReorderGraph(object):
|
||||||
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
||||||
|
|
||||||
def _reorder_self_node_list(self, reorder_map):
|
def _reorder_self_node_list(self, reorder_map):
|
||||||
new_node_list = [None for _ in range(len(self.trace_index.node_list))]
|
new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
|
||||||
for old_idx, new_idx in reorder_map.items():
|
for old_idx, new_idx in reorder_map.items():
|
||||||
new_node_list[new_idx] = self.trace_index.node_list[old_idx]
|
new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
|
||||||
self.trace_index.node_list = new_node_list
|
self.trace_indice.node_list = new_node_list
|
||||||
|
|
||||||
def _reorder_idx_trace(self, reorder_map):
|
def _reorder_idx_trace(self, reorder_map):
|
||||||
# reorder list
|
# reorder list
|
||||||
new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))]
|
new_idx_trace_list = [None for _ in range(len(self.trace_indice.idx_trace_list))]
|
||||||
for old_idx, new_idx in reorder_map.items():
|
for old_idx, new_idx in reorder_map.items():
|
||||||
new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx]
|
new_idx_trace_list[new_idx] = self.trace_indice.idx_trace_list[old_idx]
|
||||||
self.trace_index.idx_trace_list = new_idx_trace_list
|
self.trace_indice.idx_trace_list = new_idx_trace_list
|
||||||
# update compute
|
# update compute
|
||||||
for idx_trace in self.trace_index.idx_trace_list:
|
for idx_trace in self.trace_indice.idx_trace_list:
|
||||||
compute = idx_trace["compute"]
|
compute = idx_trace["compute"]
|
||||||
for dim_compute in compute:
|
for dim_compute in compute:
|
||||||
for idx, i in enumerate(dim_compute):
|
for idx, i in enumerate(dim_compute):
|
||||||
dim_compute[idx] = reorder_map[i]
|
dim_compute[idx] = reorder_map[i]
|
||||||
# update source
|
# update source
|
||||||
for idx_trace in self.trace_index.idx_trace_list:
|
for idx_trace in self.trace_indice.idx_trace_list:
|
||||||
source = idx_trace["source"]
|
source = idx_trace["source"]
|
||||||
for dim_idx, dim_source in enumerate(source):
|
for dim_idx, dim_source in enumerate(source):
|
||||||
new_dim_source = {}
|
new_dim_source = {}
|
||||||
|
|
|
@ -7,7 +7,7 @@ from .estimate_memory import EstimateMemory
|
||||||
from .reorder_graph import ReorderGraph
|
from .reorder_graph import ReorderGraph
|
||||||
from .select_chunk import SelectChunk
|
from .select_chunk import SelectChunk
|
||||||
from .trace_flow import TraceFlow
|
from .trace_flow import TraceFlow
|
||||||
from .trace_index import TraceIndex
|
from .trace_indice import TraceIndice
|
||||||
from .utils import (
|
from .utils import (
|
||||||
get_node_shape,
|
get_node_shape,
|
||||||
is_non_compute_node,
|
is_non_compute_node,
|
||||||
|
@ -47,13 +47,13 @@ class SearchChunk(object):
|
||||||
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
self.print_mem = print_mem
|
self.print_mem = print_mem
|
||||||
self.trace_index = TraceIndex(list(gm.graph.nodes))
|
self.trace_indice = TraceIndice(list(gm.graph.nodes))
|
||||||
self.trace_index.trace_index()
|
self.trace_indice.trace_index()
|
||||||
self.trace_flow = TraceFlow(self.trace_index)
|
self.trace_flow = TraceFlow(self.trace_indice)
|
||||||
self.reorder_graph = ReorderGraph(self.trace_index)
|
self.reorder_graph = ReorderGraph(self.trace_indice)
|
||||||
self.estimate_memory = EstimateMemory()
|
self.estimate_memory = EstimateMemory()
|
||||||
self.select_chunk = SelectChunk(
|
self.select_chunk = SelectChunk(
|
||||||
self.trace_index,
|
self.trace_indice,
|
||||||
self.estimate_memory,
|
self.estimate_memory,
|
||||||
self.reorder_graph,
|
self.reorder_graph,
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
|
@ -72,7 +72,7 @@ class SearchChunk(object):
|
||||||
free_var_idx (List): all indexs of free vars
|
free_var_idx (List): all indexs of free vars
|
||||||
"""
|
"""
|
||||||
free_var_idx = []
|
free_var_idx = []
|
||||||
for idx, n in enumerate(self.trace_index.node_list):
|
for idx, n in enumerate(self.trace_indice.node_list):
|
||||||
if n.op == "placeholder":
|
if n.op == "placeholder":
|
||||||
free_var_idx.append(idx)
|
free_var_idx.append(idx)
|
||||||
return free_var_idx
|
return free_var_idx
|
||||||
|
@ -156,7 +156,7 @@ class SearchChunk(object):
|
||||||
"""
|
"""
|
||||||
start_traces = input_trace[start_idx]
|
start_traces = input_trace[start_idx]
|
||||||
end_trace = output_trace[end_idx]
|
end_trace = output_trace[end_idx]
|
||||||
end_node = self.trace_index.node_list[end_idx]
|
end_node = self.trace_indice.node_list[end_idx]
|
||||||
chunk_infos = []
|
chunk_infos = []
|
||||||
for end_dim, _ in enumerate(end_trace["idx"]):
|
for end_dim, _ in enumerate(end_trace["idx"]):
|
||||||
if len(start_traces) > 1:
|
if len(start_traces) > 1:
|
||||||
|
@ -205,23 +205,23 @@ class SearchChunk(object):
|
||||||
possible_chunk_region (List)
|
possible_chunk_region (List)
|
||||||
"""
|
"""
|
||||||
possible_chunk_region = []
|
possible_chunk_region = []
|
||||||
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
|
output_trace = copy.deepcopy(self.trace_indice.idx_trace_list)
|
||||||
input_trace = [] # trace of a node's input nodes
|
input_trace = [] # trace of a node's input nodes
|
||||||
for _, n in enumerate(self.trace_index.node_list):
|
for _, n in enumerate(self.trace_indice.node_list):
|
||||||
cur_trace = {}
|
cur_trace = {}
|
||||||
for arg in n.args:
|
for arg in n.args:
|
||||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
|
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
|
||||||
arg
|
arg
|
||||||
):
|
):
|
||||||
cur_trace[arg] = self.trace_index._find_trace_from_node(arg)
|
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
|
||||||
input_trace.append(cur_trace)
|
input_trace.append(cur_trace)
|
||||||
|
|
||||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||||
# skip non compute nodes
|
# skip non compute nodes
|
||||||
if is_non_compute_node(
|
if is_non_compute_node(
|
||||||
self.trace_index.node_list[start_idx]
|
self.trace_indice.node_list[start_idx]
|
||||||
) or is_non_compute_node(self.trace_index.node_list[end_idx]):
|
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# select free dim
|
# select free dim
|
||||||
|
@ -292,7 +292,7 @@ class SearchChunk(object):
|
||||||
_,
|
_,
|
||||||
active_node,
|
active_node,
|
||||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
self.trace_index.node_list
|
self.trace_indice.node_list
|
||||||
)
|
)
|
||||||
mem_peak = init_mem_peak
|
mem_peak = init_mem_peak
|
||||||
|
|
||||||
|
@ -307,13 +307,13 @@ class SearchChunk(object):
|
||||||
_,
|
_,
|
||||||
active_node,
|
active_node,
|
||||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
self.trace_index.node_list, chunk_infos
|
self.trace_indice.node_list, chunk_infos
|
||||||
)
|
)
|
||||||
if self._stop_search(init_mem_peak, mem_peak):
|
if self._stop_search(init_mem_peak, mem_peak):
|
||||||
break
|
break
|
||||||
if self.print_mem:
|
if self.print_mem:
|
||||||
self.print_mem = False
|
self.print_mem = False
|
||||||
self.estimate_memory.estimate_chunk_inference_mem(
|
self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
self.trace_index.node_list, chunk_infos, print_mem=True
|
self.trace_indice.node_list, chunk_infos, print_mem=True
|
||||||
)
|
)
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
|
|
@ -1,19 +1,19 @@
|
||||||
from .estimate_memory import EstimateMemory
|
from .estimate_memory import EstimateMemory
|
||||||
from .reorder_graph import ReorderGraph
|
from .reorder_graph import ReorderGraph
|
||||||
from .trace_index import TraceIndex
|
from .trace_indice import TraceIndice
|
||||||
from .utils import is_non_compute_node
|
from .utils import is_non_compute_node
|
||||||
|
|
||||||
|
|
||||||
class SelectChunk(object):
|
class SelectChunk(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
trace_index: TraceIndex,
|
trace_indice: TraceIndice,
|
||||||
estimate_memory: EstimateMemory,
|
estimate_memory: EstimateMemory,
|
||||||
reorder_graph: ReorderGraph,
|
reorder_graph: ReorderGraph,
|
||||||
max_memory=None,
|
max_memory=None,
|
||||||
):
|
):
|
||||||
self.index_tracer = trace_index
|
self.trace_indice = trace_indice
|
||||||
self.memory_estimator = estimate_memory
|
self.estimate_memory = estimate_memory
|
||||||
self.reorder_graph = reorder_graph
|
self.reorder_graph = reorder_graph
|
||||||
if max_memory is not None:
|
if max_memory is not None:
|
||||||
self.stratge = "fit_memory"
|
self.stratge = "fit_memory"
|
||||||
|
@ -68,10 +68,10 @@ class SelectChunk(object):
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
cur_region = region.copy()
|
cur_region = region.copy()
|
||||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||||
self.index_tracer.node_list, cur_region
|
self.trace_indice.node_list, cur_region
|
||||||
)
|
)
|
||||||
cur_chunk_infos = chunk_infos + [cur_region]
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
cur_node_list, cur_chunk_infos
|
cur_node_list, cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_region_peak = cur_mem_peak[
|
cur_chunk_region_peak = cur_mem_peak[
|
||||||
|
@ -113,7 +113,7 @@ class SelectChunk(object):
|
||||||
chunk_size *= 2
|
chunk_size *= 2
|
||||||
reorder_chunk_info["chunk_size"] = chunk_size
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
|
@ -139,7 +139,7 @@ class SelectChunk(object):
|
||||||
mid = int((left + right) / 2 + 0.5)
|
mid = int((left + right) / 2 + 0.5)
|
||||||
chunk_info["chunk_size"] = mid
|
chunk_info["chunk_size"] = mid
|
||||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
|
@ -153,7 +153,7 @@ class SelectChunk(object):
|
||||||
|
|
||||||
def _get_compute_node_num(self, start, end):
|
def _get_compute_node_num(self, start, end):
|
||||||
count = 0
|
count = 0
|
||||||
for i in self.index_tracer.node_list[start : end + 1]:
|
for i in self.trace_indice.node_list[start : end + 1]:
|
||||||
if not is_non_compute_node(i):
|
if not is_non_compute_node(i):
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
@ -178,10 +178,10 @@ class SelectChunk(object):
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
cur_region = region.copy()
|
cur_region = region.copy()
|
||||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||||
self.index_tracer.node_list, cur_region
|
self.trace_indice.node_list, cur_region
|
||||||
)
|
)
|
||||||
cur_chunk_infos = chunk_infos + [cur_region]
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
cur_node_list, cur_chunk_infos
|
cur_node_list, cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_region_peak = cur_mem_peak[
|
cur_chunk_region_peak = cur_mem_peak[
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .trace_index import TraceIndex
|
from .trace_indice import TraceIndice
|
||||||
from .utils import (
|
from .utils import (
|
||||||
find_chunk_all_input_nodes,
|
find_chunk_all_input_nodes,
|
||||||
find_chunk_compute_input_and_output_nodes,
|
find_chunk_compute_input_and_output_nodes,
|
||||||
|
@ -10,8 +10,8 @@ from .utils import (
|
||||||
|
|
||||||
|
|
||||||
class TraceFlow(object):
|
class TraceFlow(object):
|
||||||
def __init__(self, trace_index: TraceIndex) -> None:
|
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||||
self.trace_index = trace_index
|
self.trace_indice = trace_indice
|
||||||
|
|
||||||
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
||||||
"""
|
"""
|
||||||
|
@ -25,8 +25,8 @@ class TraceFlow(object):
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if check pass
|
bool: True if check pass
|
||||||
"""
|
"""
|
||||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list)
|
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||||
end_node_trace = self.trace_index._find_trace_from_node(end_node)
|
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||||
sorted_source = sorted(
|
sorted_source = sorted(
|
||||||
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||||
|
@ -51,24 +51,24 @@ class TraceFlow(object):
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if check pass
|
bool: True if check pass
|
||||||
"""
|
"""
|
||||||
end_node_trace = self.trace_index._find_trace_from_node(end_node)
|
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||||
end_node_compute = end_node_trace["compute"][end_dim]
|
end_node_compute = end_node_trace["compute"][end_dim]
|
||||||
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||||
node_from_source = self.trace_index._find_source_trace_from_node(node_from)
|
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
|
||||||
dim_source = node_from_source[node_from_dim]
|
dim_source = node_from_source[node_from_dim]
|
||||||
node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list)
|
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
|
||||||
for k, v in dim_source.items():
|
for k, v in dim_source.items():
|
||||||
if k == node_to_idx:
|
if k == node_to_idx:
|
||||||
return v
|
return v
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list)
|
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||||
node_trace_source = self.trace_index._find_source_trace_from_node(node)
|
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||||
for node_dim in range(len(get_node_shape(node))):
|
for node_dim in range(len(get_node_shape(node))):
|
||||||
if (
|
if (
|
||||||
input_node_idx in node_trace_source[node_dim]
|
input_node_idx in node_trace_source[node_dim]
|
||||||
|
@ -82,19 +82,19 @@ class TraceFlow(object):
|
||||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||||
inherit_dim = self._find_inherit_dim(
|
inherit_dim = self._find_inherit_dim(
|
||||||
input_node, v, self.trace_index.node_list[k]
|
input_node, v, self.trace_indice.node_list[k]
|
||||||
)
|
)
|
||||||
if inherit_dim:
|
if inherit_dim:
|
||||||
input_dim_after_node[k] = inherit_dim
|
input_dim_after_node[k] = inherit_dim
|
||||||
|
|
||||||
for node in self.trace_index.node_list[
|
for node in self.trace_indice.node_list[
|
||||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||||
]:
|
]:
|
||||||
if is_non_compute_node_except_placeholder(node):
|
if is_non_compute_node_except_placeholder(node):
|
||||||
continue
|
continue
|
||||||
count = 0
|
count = 0
|
||||||
duplicate_dims = []
|
duplicate_dims = []
|
||||||
node_trace_source = self.trace_index._find_source_trace_from_node(node)
|
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||||
for node_dim in range(len(get_node_shape(node))):
|
for node_dim in range(len(get_node_shape(node))):
|
||||||
duplicate_dim = []
|
duplicate_dim = []
|
||||||
duplicate_flag = False
|
duplicate_flag = False
|
||||||
|
@ -130,7 +130,7 @@ class TraceFlow(object):
|
||||||
all_node_info,
|
all_node_info,
|
||||||
next_node_list,
|
next_node_list,
|
||||||
):
|
):
|
||||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list)
|
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||||
# arg in chunk range or be inputs
|
# arg in chunk range or be inputs
|
||||||
if not (start_idx <= arg_idx < end_idx):
|
if not (start_idx <= arg_idx < end_idx):
|
||||||
return True
|
return True
|
||||||
|
@ -171,7 +171,7 @@ class TraceFlow(object):
|
||||||
|
|
||||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||||
cur_node_list = [
|
cur_node_list = [
|
||||||
self.trace_index.node_list[end_idx]
|
self.trace_indice.node_list[end_idx]
|
||||||
] # start from the last node
|
] # start from the last node
|
||||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||||
|
|
||||||
|
@ -183,10 +183,10 @@ class TraceFlow(object):
|
||||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||||
if cur_node_chunk_dim:
|
if cur_node_chunk_dim:
|
||||||
cur_node_compute = self.trace_index._find_compute_trace_from_node(
|
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
|
||||||
cur_node
|
cur_node
|
||||||
)
|
)
|
||||||
cur_node_source = self.trace_index._find_source_trace_from_node(
|
cur_node_source = self.trace_indice._find_source_trace_from_node(
|
||||||
cur_node
|
cur_node
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -220,7 +220,7 @@ class TraceFlow(object):
|
||||||
if not (
|
if not (
|
||||||
start_idx
|
start_idx
|
||||||
<= find_idx_by_name(
|
<= find_idx_by_name(
|
||||||
arg.name, self.trace_index.node_list
|
arg.name, self.trace_indice.node_list
|
||||||
)
|
)
|
||||||
< end_idx
|
< end_idx
|
||||||
):
|
):
|
||||||
|
@ -250,16 +250,16 @@ class TraceFlow(object):
|
||||||
for input_node in inputs:
|
for input_node in inputs:
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
input_node_idx = find_idx_by_name(
|
input_node_idx = find_idx_by_name(
|
||||||
input_node.name, self.trace_index.node_list
|
input_node.name, self.trace_indice.node_list
|
||||||
)
|
)
|
||||||
for user in input_node.users.keys():
|
for user in input_node.users.keys():
|
||||||
if is_non_compute_node(user):
|
if is_non_compute_node(user):
|
||||||
continue
|
continue
|
||||||
user_idx = find_idx_by_name(user.name, self.trace_index.node_list)
|
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
|
||||||
if start_idx <= user_idx <= end_idx:
|
if start_idx <= user_idx <= end_idx:
|
||||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||||
if chunk_dim is not None:
|
if chunk_dim is not None:
|
||||||
user_source = self.trace_index._find_source_trace_from_node(
|
user_source = self.trace_indice._find_source_trace_from_node(
|
||||||
user
|
user
|
||||||
)[chunk_dim]
|
)[chunk_dim]
|
||||||
if input_node_idx in user_source:
|
if input_node_idx in user_source:
|
||||||
|
@ -282,7 +282,7 @@ class TraceFlow(object):
|
||||||
if node_info["chunk_dim"] is None:
|
if node_info["chunk_dim"] is None:
|
||||||
maybe_prepose_nodes.append(node)
|
maybe_prepose_nodes.append(node)
|
||||||
maybe_prepose_nodes.sort(
|
maybe_prepose_nodes.sort(
|
||||||
key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list),
|
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
) # from last node to first node
|
) # from last node to first node
|
||||||
prepose_nodes = []
|
prepose_nodes = []
|
||||||
|
@ -308,7 +308,7 @@ class TraceFlow(object):
|
||||||
if not (
|
if not (
|
||||||
start_idx
|
start_idx
|
||||||
<= find_idx_by_name(
|
<= find_idx_by_name(
|
||||||
cur_prepose_node_arg.name, self.trace_index.node_list
|
cur_prepose_node_arg.name, self.trace_indice.node_list
|
||||||
)
|
)
|
||||||
< end_idx
|
< end_idx
|
||||||
):
|
):
|
||||||
|
@ -336,14 +336,14 @@ class TraceFlow(object):
|
||||||
maybe_prepose_nodes.remove(n)
|
maybe_prepose_nodes.remove(n)
|
||||||
# sort by index
|
# sort by index
|
||||||
prepose_nodes.sort(
|
prepose_nodes.sort(
|
||||||
key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list)
|
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
|
||||||
)
|
)
|
||||||
|
|
||||||
return prepose_nodes
|
return prepose_nodes
|
||||||
|
|
||||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||||
# we need to log input nodes to avoid deleteing them in the loop
|
# we need to log input nodes to avoid deleteing them in the loop
|
||||||
chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1]
|
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||||
for n in chunk_info["args"]["prepose_nodes"]:
|
for n in chunk_info["args"]["prepose_nodes"]:
|
||||||
chunk_node_list.remove(n)
|
chunk_node_list.remove(n)
|
||||||
|
@ -355,7 +355,7 @@ class TraceFlow(object):
|
||||||
|
|
||||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||||
self.trace_index.node_list[start_idx : end_idx + 1]
|
self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||||
)
|
)
|
||||||
# only single ouput
|
# only single ouput
|
||||||
if len(outputs) > 1:
|
if len(outputs) > 1:
|
||||||
|
@ -403,10 +403,10 @@ class TraceFlow(object):
|
||||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
||||||
chunk_info["outputs_dim"]
|
chunk_info["outputs_dim"]
|
||||||
]
|
]
|
||||||
for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||||
if any(i in node.name for i in ["reshape", "view"]):
|
if any(i in node.name for i in ["reshape", "view"]):
|
||||||
reshape_args = node.args[1:]
|
reshape_args = node.args[1:]
|
||||||
reshape_log = self.trace_index.idx_view_list[node]
|
reshape_log = self.trace_indice.idx_view_list[node]
|
||||||
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||||
reshape_size[node.name] = {}
|
reshape_size[node.name] = {}
|
||||||
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
||||||
|
|
|
@ -6,7 +6,7 @@ from .utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TraceIndex(object):
|
class TraceIndice(object):
|
||||||
def __init__(self, node_list) -> None:
|
def __init__(self, node_list) -> None:
|
||||||
self.node_list = node_list
|
self.node_list = node_list
|
||||||
self.idx_trace_list = self._init_idx_trace_list()
|
self.idx_trace_list = self._init_idx_trace_list()
|
Loading…
Reference in New Issue