rename trace_index to trace_indice

pull/2364/head
oahzxl 2023-01-09 17:25:13 +08:00
parent 065f0b4c27
commit 0ea903b94e
6 changed files with 74 additions and 74 deletions

View File

@ -94,9 +94,9 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
return context return context
def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body): def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_idx, node, body):
if "ones_like" in node.name: if "ones_like" in node.name:
meta_node = search_chunk.trace_index.node_list[node_idx] meta_node = search_chunk.trace_indice.node_list[node_idx]
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1: if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0] source_node = meta_node.args[0].args[0]

View File

@ -1,22 +1,22 @@
from .trace_index import TraceIndex from .trace_indice import TraceIndice
from .utils import find_idx_by_name from .utils import find_idx_by_name
class ReorderGraph(object): class ReorderGraph(object):
def __init__(self, trace_index: TraceIndex) -> None: def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_index = trace_index self.trace_indice = trace_indice
self.all_reorder_map = { self.all_reorder_map = {
i: i for i in range(len(self.trace_index.idx_trace_list)) i: i for i in range(len(self.trace_indice.idx_trace_list))
} }
def _get_reorder_map(self, chunk_info): def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.trace_index.node_list))} reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
chunk_region_start = chunk_info["region"][0] chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1] chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [ chunk_prepose_nodes_idx = [
find_idx_by_name(i.name, self.trace_index.node_list) find_idx_by_name(i.name, self.trace_indice.node_list)
for i in chunk_prepose_nodes for i in chunk_prepose_nodes
] ]
# put prepose nodes ahead # put prepose nodes ahead
@ -24,10 +24,10 @@ class ReorderGraph(object):
n_idx = chunk_prepose_nodes_idx[idx] n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes # put other nodes after prepose nodes
for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]: for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
if n in chunk_prepose_nodes: if n in chunk_prepose_nodes:
continue continue
n_idx = find_idx_by_name(n.name, self.trace_index.node_list) n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos reorder_map[n_idx] = n_idx + pos
@ -53,25 +53,25 @@ class ReorderGraph(object):
self.all_reorder_map[origin_idx] = reorder_map[map_idx] self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map): def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.trace_index.node_list))] new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.trace_index.node_list[old_idx] new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
self.trace_index.node_list = new_node_list self.trace_indice.node_list = new_node_list
def _reorder_idx_trace(self, reorder_map): def _reorder_idx_trace(self, reorder_map):
# reorder list # reorder list
new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))] new_idx_trace_list = [None for _ in range(len(self.trace_indice.idx_trace_list))]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx] new_idx_trace_list[new_idx] = self.trace_indice.idx_trace_list[old_idx]
self.trace_index.idx_trace_list = new_idx_trace_list self.trace_indice.idx_trace_list = new_idx_trace_list
# update compute # update compute
for idx_trace in self.trace_index.idx_trace_list: for idx_trace in self.trace_indice.idx_trace_list:
compute = idx_trace["compute"] compute = idx_trace["compute"]
for dim_compute in compute: for dim_compute in compute:
for idx, i in enumerate(dim_compute): for idx, i in enumerate(dim_compute):
dim_compute[idx] = reorder_map[i] dim_compute[idx] = reorder_map[i]
# update source # update source
for idx_trace in self.trace_index.idx_trace_list: for idx_trace in self.trace_indice.idx_trace_list:
source = idx_trace["source"] source = idx_trace["source"]
for dim_idx, dim_source in enumerate(source): for dim_idx, dim_source in enumerate(source):
new_dim_source = {} new_dim_source = {}

View File

@ -7,7 +7,7 @@ from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk from .select_chunk import SelectChunk
from .trace_flow import TraceFlow from .trace_flow import TraceFlow
from .trace_index import TraceIndex from .trace_indice import TraceIndice
from .utils import ( from .utils import (
get_node_shape, get_node_shape,
is_non_compute_node, is_non_compute_node,
@ -47,13 +47,13 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False) -> None: def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm self.gm = gm
self.print_mem = print_mem self.print_mem = print_mem
self.trace_index = TraceIndex(list(gm.graph.nodes)) self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.trace_index.trace_index() self.trace_indice.trace_index()
self.trace_flow = TraceFlow(self.trace_index) self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_index) self.reorder_graph = ReorderGraph(self.trace_indice)
self.estimate_memory = EstimateMemory() self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk( self.select_chunk = SelectChunk(
self.trace_index, self.trace_indice,
self.estimate_memory, self.estimate_memory,
self.reorder_graph, self.reorder_graph,
max_memory=max_memory, max_memory=max_memory,
@ -72,7 +72,7 @@ class SearchChunk(object):
free_var_idx (List): all indices of free vars free_var_idx (List): all indices of free vars
""" """
free_var_idx = [] free_var_idx = []
for idx, n in enumerate(self.trace_index.node_list): for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder": if n.op == "placeholder":
free_var_idx.append(idx) free_var_idx.append(idx)
return free_var_idx return free_var_idx
@ -156,7 +156,7 @@ class SearchChunk(object):
""" """
start_traces = input_trace[start_idx] start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx] end_trace = output_trace[end_idx]
end_node = self.trace_index.node_list[end_idx] end_node = self.trace_indice.node_list[end_idx]
chunk_infos = [] chunk_infos = []
for end_dim, _ in enumerate(end_trace["idx"]): for end_dim, _ in enumerate(end_trace["idx"]):
if len(start_traces) > 1: if len(start_traces) > 1:
@ -205,23 +205,23 @@ class SearchChunk(object):
possible_chunk_region (List) possible_chunk_region (List)
""" """
possible_chunk_region = [] possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_index.idx_trace_list) output_trace = copy.deepcopy(self.trace_indice.idx_trace_list)
input_trace = [] # trace of a node's input nodes input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_index.node_list): for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {} cur_trace = {}
for arg in n.args: for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder( if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
arg arg
): ):
cur_trace[arg] = self.trace_index._find_trace_from_node(arg) cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace) input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1): for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes # skip non compute nodes
if is_non_compute_node( if is_non_compute_node(
self.trace_index.node_list[start_idx] self.trace_indice.node_list[start_idx]
) or is_non_compute_node(self.trace_index.node_list[end_idx]): ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
continue continue
# select free dim # select free dim
@ -292,7 +292,7 @@ class SearchChunk(object):
_, _,
active_node, active_node,
) = self.estimate_memory.estimate_chunk_inference_mem( ) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list self.trace_indice.node_list
) )
mem_peak = init_mem_peak mem_peak = init_mem_peak
@ -307,13 +307,13 @@ class SearchChunk(object):
_, _,
active_node, active_node,
) = self.estimate_memory.estimate_chunk_inference_mem( ) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list, chunk_infos self.trace_indice.node_list, chunk_infos
) )
if self._stop_search(init_mem_peak, mem_peak): if self._stop_search(init_mem_peak, mem_peak):
break break
if self.print_mem: if self.print_mem:
self.print_mem = False self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem( self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list, chunk_infos, print_mem=True self.trace_indice.node_list, chunk_infos, print_mem=True
) )
return chunk_infos return chunk_infos

View File

@ -1,19 +1,19 @@
from .estimate_memory import EstimateMemory from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph from .reorder_graph import ReorderGraph
from .trace_index import TraceIndex from .trace_indice import TraceIndice
from .utils import is_non_compute_node from .utils import is_non_compute_node
class SelectChunk(object): class SelectChunk(object):
def __init__( def __init__(
self, self,
trace_index: TraceIndex, trace_indice: TraceIndice,
estimate_memory: EstimateMemory, estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph, reorder_graph: ReorderGraph,
max_memory=None, max_memory=None,
): ):
self.index_tracer = trace_index self.trace_indice = trace_indice
self.memory_estimator = estimate_memory self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph self.reorder_graph = reorder_graph
if max_memory is not None: if max_memory is not None:
self.stratge = "fit_memory" self.stratge = "fit_memory"
@ -68,10 +68,10 @@ class SelectChunk(object):
for region in possible_chunk_regions: for region in possible_chunk_regions:
cur_region = region.copy() cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder( cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.index_tracer.node_list, cur_region self.trace_indice.node_list, cur_region
) )
cur_chunk_infos = chunk_infos + [cur_region] cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos cur_node_list, cur_chunk_infos
)[0] )[0]
cur_chunk_region_peak = cur_mem_peak[ cur_chunk_region_peak = cur_mem_peak[
@ -113,7 +113,7 @@ class SelectChunk(object):
chunk_size *= 2 chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info] cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0] )[0]
cur_chunk_max_mem = max( cur_chunk_max_mem = max(
@ -139,7 +139,7 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5) mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info] cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0] )[0]
cur_chunk_max_mem = max( cur_chunk_max_mem = max(
@ -153,7 +153,7 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end): def _get_compute_node_num(self, start, end):
count = 0 count = 0
for i in self.index_tracer.node_list[start : end + 1]: for i in self.trace_indice.node_list[start : end + 1]:
if not is_non_compute_node(i): if not is_non_compute_node(i):
count += 1 count += 1
return count return count
@ -178,10 +178,10 @@ class SelectChunk(object):
for region in possible_chunk_regions: for region in possible_chunk_regions:
cur_region = region.copy() cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder( cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.index_tracer.node_list, cur_region self.trace_indice.node_list, cur_region
) )
cur_chunk_infos = chunk_infos + [cur_region] cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos cur_node_list, cur_chunk_infos
)[0] )[0]
cur_chunk_region_peak = cur_mem_peak[ cur_chunk_region_peak = cur_mem_peak[

View File

@ -1,4 +1,4 @@
from .trace_index import TraceIndex from .trace_indice import TraceIndice
from .utils import ( from .utils import (
find_chunk_all_input_nodes, find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes, find_chunk_compute_input_and_output_nodes,
@ -10,8 +10,8 @@ from .utils import (
class TraceFlow(object): class TraceFlow(object):
def __init__(self, trace_index: TraceIndex) -> None: def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_index = trace_index self.trace_indice = trace_indice
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
""" """
@ -25,8 +25,8 @@ class TraceFlow(object):
Returns: Returns:
bool: True if check pass bool: True if check pass
""" """
start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list) start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
end_node_trace = self.trace_index._find_trace_from_node(end_node) end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim] end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted( sorted_source = sorted(
end_node_trace_source.items(), key=lambda d: d[0], reverse=True end_node_trace_source.items(), key=lambda d: d[0], reverse=True
@ -51,24 +51,24 @@ class TraceFlow(object):
Returns: Returns:
bool: True if check pass bool: True if check pass
""" """
end_node_trace = self.trace_index._find_trace_from_node(end_node) end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_compute = end_node_trace["compute"][end_dim] end_node_compute = end_node_trace["compute"][end_dim]
if any(start_idx <= i <= end_idx for i in end_node_compute): if any(start_idx <= i <= end_idx for i in end_node_compute):
return False return False
return True return True
def get_node_chunk_dim(self, node_from, node_from_dim, node_to): def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self.trace_index._find_source_trace_from_node(node_from) node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim] dim_source = node_from_source[node_from_dim]
node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list) node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
for k, v in dim_source.items(): for k, v in dim_source.items():
if k == node_to_idx: if k == node_to_idx:
return v return v
return None return None
def _find_inherit_dim(self, input_node, input_dim, node): def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list) input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
node_trace_source = self.trace_index._find_source_trace_from_node(node) node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))): for node_dim in range(len(get_node_shape(node))):
if ( if (
input_node_idx in node_trace_source[node_dim] input_node_idx in node_trace_source[node_dim]
@ -82,19 +82,19 @@ class TraceFlow(object):
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim( inherit_dim = self._find_inherit_dim(
input_node, v, self.trace_index.node_list[k] input_node, v, self.trace_indice.node_list[k]
) )
if inherit_dim: if inherit_dim:
input_dim_after_node[k] = inherit_dim input_dim_after_node[k] = inherit_dim
for node in self.trace_index.node_list[ for node in self.trace_indice.node_list[
chunk_infos["region"][0] : chunk_infos["region"][1] + 1 chunk_infos["region"][0] : chunk_infos["region"][1] + 1
]: ]:
if is_non_compute_node_except_placeholder(node): if is_non_compute_node_except_placeholder(node):
continue continue
count = 0 count = 0
duplicate_dims = [] duplicate_dims = []
node_trace_source = self.trace_index._find_source_trace_from_node(node) node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))): for node_dim in range(len(get_node_shape(node))):
duplicate_dim = [] duplicate_dim = []
duplicate_flag = False duplicate_flag = False
@ -130,7 +130,7 @@ class TraceFlow(object):
all_node_info, all_node_info,
next_node_list, next_node_list,
): ):
arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list) arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
# arg in chunk range or be inputs # arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx): if not (start_idx <= arg_idx < end_idx):
return True return True
@ -171,7 +171,7 @@ class TraceFlow(object):
def _get_all_node_info(self, end_dim, start_idx, end_idx): def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [ cur_node_list = [
self.trace_index.node_list[end_idx] self.trace_indice.node_list[end_idx]
] # start from the last node ] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
@ -183,10 +183,10 @@ class TraceFlow(object):
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim: if cur_node_chunk_dim:
cur_node_compute = self.trace_index._find_compute_trace_from_node( cur_node_compute = self.trace_indice._find_compute_trace_from_node(
cur_node cur_node
) )
cur_node_source = self.trace_index._find_source_trace_from_node( cur_node_source = self.trace_indice._find_source_trace_from_node(
cur_node cur_node
) )
else: else:
@ -220,7 +220,7 @@ class TraceFlow(object):
if not ( if not (
start_idx start_idx
<= find_idx_by_name( <= find_idx_by_name(
arg.name, self.trace_index.node_list arg.name, self.trace_indice.node_list
) )
< end_idx < end_idx
): ):
@ -250,16 +250,16 @@ class TraceFlow(object):
for input_node in inputs: for input_node in inputs:
input_dict = {} input_dict = {}
input_node_idx = find_idx_by_name( input_node_idx = find_idx_by_name(
input_node.name, self.trace_index.node_list input_node.name, self.trace_indice.node_list
) )
for user in input_node.users.keys(): for user in input_node.users.keys():
if is_non_compute_node(user): if is_non_compute_node(user):
continue continue
user_idx = find_idx_by_name(user.name, self.trace_index.node_list) user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
if start_idx <= user_idx <= end_idx: if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"] chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None: if chunk_dim is not None:
user_source = self.trace_index._find_source_trace_from_node( user_source = self.trace_indice._find_source_trace_from_node(
user user
)[chunk_dim] )[chunk_dim]
if input_node_idx in user_source: if input_node_idx in user_source:
@ -282,7 +282,7 @@ class TraceFlow(object):
if node_info["chunk_dim"] is None: if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node) maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort( maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list), key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
reverse=True, reverse=True,
) # from last node to first node ) # from last node to first node
prepose_nodes = [] prepose_nodes = []
@ -308,7 +308,7 @@ class TraceFlow(object):
if not ( if not (
start_idx start_idx
<= find_idx_by_name( <= find_idx_by_name(
cur_prepose_node_arg.name, self.trace_index.node_list cur_prepose_node_arg.name, self.trace_indice.node_list
) )
< end_idx < end_idx
): ):
@ -336,14 +336,14 @@ class TraceFlow(object):
maybe_prepose_nodes.remove(n) maybe_prepose_nodes.remove(n)
# sort by index # sort by index
prepose_nodes.sort( prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list) key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
) )
return prepose_nodes return prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop # we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1] chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
# also need to get some prepose node's arg out of non_chunk_inputs # also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]: for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n) chunk_node_list.remove(n)
@ -355,7 +355,7 @@ class TraceFlow(object):
def flow_search(self, start_idx, start_dim, end_idx, end_dim): def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes( inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.trace_index.node_list[start_idx : end_idx + 1] self.trace_indice.node_list[start_idx : end_idx + 1]
) )
# only single output # only single output
if len(outputs) > 1: if len(outputs) > 1:
@ -403,10 +403,10 @@ class TraceFlow(object):
chunk_shape = get_node_shape(chunk_info["outputs"][0])[ chunk_shape = get_node_shape(chunk_info["outputs"][0])[
chunk_info["outputs_dim"] chunk_info["outputs_dim"]
] ]
for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]: for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]): if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:] reshape_args = node.args[1:]
reshape_log = self.trace_index.idx_view_list[node] reshape_log = self.trace_indice.idx_view_list[node]
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
reshape_size[node.name] = {} reshape_size[node.name] = {}
for reshape_arg_dim, reshape_arg in enumerate(reshape_args): for reshape_arg_dim, reshape_arg in enumerate(reshape_args):

View File

@ -6,7 +6,7 @@ from .utils import (
) )
class TraceIndex(object): class TraceIndice(object):
def __init__(self, node_list) -> None: def __init__(self, node_list) -> None:
self.node_list = node_list self.node_list = node_list
self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_list = self._init_idx_trace_list()