mirror of https://github.com/hpcaitech/ColossalAI
separate trace flow
parent 4748967fb1
commit a6cdbf9161
@@ -167,7 +167,7 @@ def emit_code_with_chunk(
             )
         # ones like
         if "ones_like" in node.name:
-            meta_node = chunk_region_search.index_tracer.node_list[node_idx]
+            meta_node = chunk_region_search.trace_index.node_list[node_idx]
             chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
                 "chunk_dim"
             ]
@@ -1,8 +1,10 @@
 import copy

 from .select_chunk import SelectChunk
-from .trace_index import TraceIndex, ReorderGraph
+from .trace_index import TraceIndex
+from .reorder_graph import ReorderGraph
 from .estiamte_memory import EstimateMemory
+from .trace_flow import TraceFlow
 from .utils import (
     get_node_shape,
     is_non_compute_node,
@@ -14,12 +16,13 @@ class SearchChunk(object):
     def __init__(self, gm, max_memory=None, print_mem=False) -> None:
         self.gm = gm
         self.print_mem = print_mem
-        self.index_tracer = TraceIndex(list(gm.graph.nodes))
-        self.index_tracer.trace_index()
-        self.reorder_graph = ReorderGraph(self.index_tracer)
-        self.memory_estimator = EstimateMemory()
-        self.chunk_selector = SelectChunk(
-            self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
+        self.trace_index = TraceIndex(list(gm.graph.nodes))
+        self.trace_index.trace_index()
+        self.trace_flow = TraceFlow(self.trace_index)
+        self.reorder_graph = ReorderGraph(self.trace_index)
+        self.estimate_memory = EstimateMemory()
+        self.select_chunk = SelectChunk(
+            self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory
         )

     def _find_peak_node(self, mem_peak):
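The constructor change is a set of attribute renames plus one new collaborator. For quick reference, the mapping implied by the hunk above, written out as a plain Python mapping (this summary is added for readability and is not code from the commit):

# Old attribute name -> new attribute name on SearchChunk, per the hunk above.
RENAMED_ATTRS = {
    "index_tracer": "trace_index",
    "memory_estimator": "estimate_memory",
    "chunk_selector": "select_chunk",
}
# Newly introduced collaborator: self.trace_flow = TraceFlow(self.trace_index)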
@@ -29,7 +32,7 @@ class SearchChunk(object):

     def _get_free_var(self):
         free_var_idx = []
-        for idx, n in enumerate(self.index_tracer.node_list):
+        for idx, n in enumerate(self.trace_index.node_list):
             if n.op == "placeholder":
                 free_var_idx.append(idx)
         return free_var_idx
@@ -99,7 +102,7 @@ class SearchChunk(object):
     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
         start_traces = input_trace[start_idx]
         end_trace = output_trace[end_idx]
-        end_node = self.index_tracer.node_list[end_idx]
+        end_node = self.trace_index.node_list[end_idx]
         chunk_infos = []
         for end_dim, _ in enumerate(end_trace["idx"]):
             if len(start_traces) > 1:
@@ -113,46 +116,46 @@ class SearchChunk(object):
                     ):
                         continue
                     # check index source align
-                    if not self.index_tracer.check_index_source(
+                    if not self.trace_flow.check_index_source(
                         start_dim, start_node, start_idx, end_dim, end_node
                     ):
                         continue
                     # check index copmute
-                    if not self.index_tracer.check_index_compute(
+                    if not self.trace_flow.check_index_compute(
                         start_idx, end_dim, end_node, end_idx
                     ):
                         continue
                     # flow search
-                    chunk_info = self.index_tracer.flow_search(
+                    chunk_info = self.trace_flow.flow_search(
                         start_idx, start_dim, end_idx, end_dim
                     )
                     if chunk_info is None:
                         continue
                     # check index copmute
-                    if not self.index_tracer.check_index_duplicate(chunk_info):
+                    if not self.trace_flow.check_index_duplicate(chunk_info):
                         continue
                     chunk_infos.append(chunk_info)
         return chunk_infos

     def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
         possible_chunk_region = []
-        output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
+        output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
         input_trace = []  # trace of a node's input nodes
-        for _, n in enumerate(self.index_tracer.node_list):
+        for _, n in enumerate(self.trace_index.node_list):
             cur_trace = {}
             for arg in n.args:
                 if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
                     arg
                 ):
-                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
+                    cur_trace[arg] = self.trace_index._find_trace_from_node(arg)
             input_trace.append(cur_trace)

         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
                 if is_non_compute_node(
-                    self.index_tracer.node_list[start_idx]
-                ) or is_non_compute_node(self.index_tracer.node_list[end_idx]):
+                    self.trace_index.node_list[start_idx]
+                ) or is_non_compute_node(self.trace_index.node_list[end_idx]):
                     continue

                 # select free dim
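Since these checks now live on a separate object, it may help to see the call sequence into TraceFlow in isolation. The following is a condensed sketch of the loop body above for a single candidate pair of dimensions; the wrapper function name is hypothetical and the early exits are flattened into returns:

# Hypothetical helper mirroring the loop body above. `searcher` is a SearchChunk
# instance; the remaining arguments describe one candidate (start, end) dim pair.
def try_candidate(searcher, start_dim, start_node, start_idx, end_dim, end_node, end_idx):
    if not searcher.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
        return None  # the two dims do not share an index source
    if not searcher.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
        return None  # the end dim was already computed inside the candidate region
    chunk_info = searcher.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
    if chunk_info is None or not searcher.trace_flow.check_index_duplicate(chunk_info):
        return None  # flow search failed, or chunked dims would collide
    return chunk_info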
@@ -173,7 +176,7 @@ class SearchChunk(object):
         possible_chunk_regions = self._search_possible_chunk_regions(
             max_chunk_region, peak_node
         )
-        best_chunk_region = self.chunk_selector._select_best_chunk_region(
+        best_chunk_region = self.select_chunk._select_best_chunk_region(
             possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
         )
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
@@ -191,8 +194,8 @@ class SearchChunk(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.memory_estimator.estimate_chunk_inference_mem(
-            self.index_tracer.node_list
+        ) = self.estimate_memory.estimate_chunk_inference_mem(
+            self.trace_index.node_list
         )
         mem_peak = init_mem_peak

@@ -206,14 +209,14 @@ class SearchChunk(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.memory_estimator.estimate_chunk_inference_mem(
-                self.index_tracer.node_list, chunk_infos
+            ) = self.estimate_memory.estimate_chunk_inference_mem(
+                self.trace_index.node_list, chunk_infos
             )
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.memory_estimator.estimate_chunk_inference_mem(
-                self.index_tracer.node_list, chunk_infos, print_mem=True
+            self.estimate_memory.estimate_chunk_inference_mem(
+                self.trace_index.node_list, chunk_infos, print_mem=True
             )
         return chunk_infos
@@ -1,4 +1,5 @@
-from .trace_index import TraceIndex, ReorderGraph
+from .trace_index import TraceIndex
+from .reorder_graph import ReorderGraph
 from .estiamte_memory import EstimateMemory
 from .utils import is_non_compute_node

@@ -0,0 +1,414 @@
+from .trace_index import TraceIndex
+from .utils import (
+    find_chunk_all_input_nodes,
+    find_chunk_compute_input_and_output_nodes,
+    find_idx_by_name,
+    get_node_shape,
+    is_non_compute_node,
+    is_non_compute_node_except_placeholder,
+)
+
+
+class TraceFlow(object):
+    def __init__(self, trace_index: TraceIndex) -> None:
+        self.trace_index = trace_index
+
+    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
+        """
+        Check 2 given index: one index should be source of the other
+        Args:
+            start_idx(int): start node chunk dim
+            start_node(node): start node
+            end_idx(int): end node chunk dim
+            end_node(node): end node
+
+        Returns:
+            bool: True if check pass
+        """
+        start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list)
+        end_node_trace = self.trace_index._find_trace_from_node(end_node)
+        end_node_trace_source = end_node_trace["source"][end_dim]
+        sorted_source = sorted(
+            end_node_trace_source.items(), key=lambda d: d[0], reverse=True
+        )
+        for node_idx, node_dim in sorted_source:
+            if node_idx == start_node_idx and start_dim in node_dim:
+                return True
+            # it means we meet a node outside the loop, and the node is not input node
+            if node_idx < start_idx:
+                return False
+        return False
+
+    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
+        """
+        Check 2 given index: check they haven't been computed in the source trace.
+        Args:
+            start_idx(int): start node chunk dim
+            start_node(node): start node
+            end_idx(int): end node chunk dim
+            end_node(node): end node
+
+        Returns:
+            bool: True if check pass
+        """
+        end_node_trace = self.trace_index._find_trace_from_node(end_node)
+        end_node_compute = end_node_trace["compute"][end_dim]
+        if any(start_idx <= i <= end_idx for i in end_node_compute):
+            return False
+        return True
+
+    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
+        node_from_source = self.trace_index._find_source_trace_from_node(node_from)
+        dim_source = node_from_source[node_from_dim]
+        node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list)
+        for k, v in dim_source.items():
+            if k == node_to_idx:
+                return v
+        return None
+
+    def _find_inherit_dim(self, input_node, input_dim, node):
+        input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list)
+        node_trace_source = self.trace_index._find_source_trace_from_node(node)
+        for node_dim in range(len(get_node_shape(node))):
+            if (
+                input_node_idx in node_trace_source[node_dim]
+                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
+            ):
+                return node_dim
+        return None
+
+    def check_index_duplicate(self, chunk_infos, return_dim=False):
+        input_dim_after_node = {}
+        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
+            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
+                inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k])
+                if inherit_dim:
+                    input_dim_after_node[k] = inherit_dim
+
+        for node in self.trace_index.node_list[
+            chunk_infos["region"][0] : chunk_infos["region"][1] + 1
+        ]:
+            if is_non_compute_node_except_placeholder(node):
+                continue
+            count = 0
+            duplicate_dims = []
+            node_trace_source = self.trace_index._find_source_trace_from_node(node)
+            for node_dim in range(len(get_node_shape(node))):
+                duplicate_dim = []
+                duplicate_flag = False
+                dim_source = node_trace_source[node_dim]
+                for k, v in dim_source.items():
+                    if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
+                        if k in input_dim_after_node and input_dim_after_node[k] in v:
+                            duplicate_flag = True
+                            duplicate_dim.append((k, v))
+                duplicate_dims.append(duplicate_dim)
+                if duplicate_flag:
+                    count += 1
+
+            if count > 1:
+                if return_dim:
+                    return False, duplicate_dims
+                else:
+                    return False
+        if return_dim:
+            return True, None
+        else:
+            return True
+
+    def _assgin_single_node_flow(
+        self,
+        arg_node,
+        start_idx,
+        end_idx,
+        cur_node_dim,
+        cur_node_compute,
+        cur_node_source,
+        cur_node_fix_dim,
+        all_node_info,
+        next_node_list,
+    ):
+        arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list)
+        # arg in chunk range or be inputs
+        if not (start_idx <= arg_idx < end_idx):
+            return True
+
+        # find arg dim
+        if cur_node_dim is not None:
+            # dim is computed
+            if arg_idx in cur_node_compute[cur_node_dim]:
+                return False
+            if arg_idx not in cur_node_source[cur_node_dim]:
+                arg_dim = None
+            else:
+                arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+        else:
+            arg_dim = None
+
+        # get fix dim
+        arg_fix_dim = []
+        if cur_node_dim is not None:
+            for i in cur_node_fix_dim:
+                fix_dim_source = cur_node_source[i]
+                if arg_idx in fix_dim_source:
+                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+
+        # if already in node_info, arg dim must be same
+        if arg_node in all_node_info:
+            if all_node_info[arg_node]["chunk_dim"] != arg_dim:
+                return False
+            all_node_info[arg_node]["fix_dim"] = list(
+                set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
+            )
+        # else add it to list
+        else:
+            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
+
+        next_node_list.append(arg_node)
+        return True
+
+    def _get_all_node_info(self, end_dim, start_idx, end_idx):
+        cur_node_list = [
+            self.trace_index.node_list[end_idx]
+        ]  # start from the last node
+        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
+
+        while len(cur_node_list) > 0:
+            next_node_list = []
+
+            for cur_node in cur_node_list:
+                # get cur node info
+                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
+                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
+                if cur_node_chunk_dim:
+                    cur_node_compute = self.trace_index._find_compute_trace_from_node(
+                        cur_node
+                    )
+                    cur_node_source = self.trace_index._find_source_trace_from_node(
+                        cur_node
+                    )
+                else:
+                    cur_node_compute = cur_node_source = None
+
+                # get all valid args
+                arg_list = []
+                for arg in cur_node.args:
+                    if type(arg) != type(cur_node):
+                        continue
+                    if is_non_compute_node(arg):
+                        continue
+                    arg_list.append(arg)
+                    flow_flag = self._assgin_single_node_flow(
+                        arg,
+                        start_idx,
+                        end_idx,
+                        cur_node_chunk_dim,
+                        cur_node_compute,
+                        cur_node_source,
+                        cur_node_fix_dim,
+                        all_node_info,
+                        next_node_list,
+                    )
+                    if flow_flag == False:
+                        return None
+
+                if len(arg_list) == 2:
+                    if any(i in cur_node.name for i in ["add", "mul"]):
+                        for arg in arg_list:
+                            if not (
+                                start_idx
+                                <= find_idx_by_name(arg.name, self.trace_index.node_list)
+                                < end_idx
+                            ):
+                                continue
+                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
+                            arg_fix_dim = all_node_info[arg]["fix_dim"]
+                            arg_shape = get_node_shape(arg)
+                            # add all dim as fix dim except chunk dim
+                            for i, shape in enumerate(arg_shape):
+                                if shape != 1 and i != cur_node_chunk_dim:
+                                    if i == arg_chunk_dim:
+                                        return None
+                                    if i not in arg_fix_dim:
+                                        arg_fix_dim.append(i)
+                    elif "einsum" in cur_node.name:
+                        pass
+                    elif "matmul" in cur_node.name:
+                        pass
+                    else:
+                        raise NotImplementedError()
+            cur_node_list = next_node_list
+        return all_node_info
+
+    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
+        inputs_dim = []
+        remove_inputs = []
+        for input_node in inputs:
+            input_dict = {}
+            input_node_idx = find_idx_by_name(
+                input_node.name, self.trace_index.node_list
+            )
+            for user in input_node.users.keys():
+                if is_non_compute_node(user):
+                    continue
+                user_idx = find_idx_by_name(user.name, self.trace_index.node_list)
+                if start_idx <= user_idx <= end_idx:
+                    chunk_dim = all_node_info[user]["chunk_dim"]
+                    if chunk_dim is not None:
+                        user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim]
+                        if input_node_idx in user_source:
+                            input_dict[user_idx] = user_source[input_node_idx]
+                        else:
+                            return None, None
+            if len(input_dict) == 0:
+                remove_inputs.append(input_node)
+            else:
+                inputs_dim.append(input_dict)
+        for i in remove_inputs:
+            if i in inputs:
+                inputs.remove(i)
+        return inputs, inputs_dim
+
+    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
+        # get all possible prepose nodes
+        maybe_prepose_nodes = []
+        for node, node_info in all_node_info.items():
+            if node_info["chunk_dim"] is None:
+                maybe_prepose_nodes.append(node)
+        maybe_prepose_nodes.sort(
+            key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list),
+            reverse=True,
+        )  # from last node to first node
+        prepose_nodes = []
+        # set every node as root, search its args, if all legal, turn root and args as prepose nodes
+        while len(maybe_prepose_nodes) > 0:
+            tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
+            tmp_cur_related_prepose_nodes = []
+            prepose_flag = True
+
+            # loop cur node's all arg until out of chunk
+            while len(tmp_cur_prepose_nodes) > 0:
+                if prepose_flag == False:
+                    break
+                tmp_next_prepose_nodes = []
+                tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
+                for cur_prepose_node in tmp_cur_prepose_nodes:
+                    if prepose_flag == False:
+                        break
+                    for cur_prepose_node_arg in cur_prepose_node.args:
+                        if type(cur_prepose_node_arg) != type(cur_prepose_node):
+                            continue
+                        # out of loop
+                        if not (
+                            start_idx
+                            <= find_idx_by_name(
+                                cur_prepose_node_arg.name, self.trace_index.node_list
+                            )
+                            < end_idx
+                        ):
+                            continue
+                        # compute op in loop
+                        elif cur_prepose_node_arg in all_node_info:
+                            if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
+                                tmp_next_prepose_nodes.append(cur_prepose_node_arg)
+                            else:
+                                prepose_flag = False
+                                break
+                        # non compute op
+                        else:
+                            tmp_next_prepose_nodes.append(cur_prepose_node_arg)
+                tmp_cur_prepose_nodes = tmp_next_prepose_nodes
+
+            if prepose_flag == False:
+                maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
+                continue
+            else:
+                for n in tmp_cur_related_prepose_nodes:
+                    if n not in prepose_nodes:
+                        prepose_nodes.append(n)
+                    if n in maybe_prepose_nodes:
+                        maybe_prepose_nodes.remove(n)
+        # sort by index
+        prepose_nodes.sort(
+            key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list)
+        )
+
+        return prepose_nodes
+
+    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
+        # we need to log input nodes to avoid deleteing them in the loop
+        chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1]
+        # also need to get some prepose node's arg out of non_chunk_inputs
+        for n in chunk_info["args"]["prepose_nodes"]:
+            chunk_node_list.remove(n)
+        non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
+        for i in non_chunk_inputs:
+            if i not in chunk_info["inputs"]:
+                chunk_info["inputs_non_chunk"].append(i)
+        return chunk_info
+
+    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
+        inputs, outputs = find_chunk_compute_input_and_output_nodes(
+            self.trace_index.node_list[start_idx : end_idx + 1]
+        )
+        # only single ouput
+        if len(outputs) > 1:
+            return None
+
+        # get every node's chunk dim and fix dim
+        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
+        if all_node_info is None:
+            return None
+
+        # get input nodes' chunk dim
+        inputs, inputs_dim = self._get_input_nodes_dim(
+            inputs, start_idx, end_idx, all_node_info
+        )
+        if inputs is None:
+            return None
+
+        chunk_info = {
+            "region": (start_idx, end_idx),
+            "inputs": inputs,
+            "inputs_non_chunk": [],
+            "inputs_dim": inputs_dim,
+            "outputs": outputs,
+            "outputs_dim": end_dim,
+            "node_chunk_dim": all_node_info,
+            "args": {},
+        }

+        # move useless nodes ahead of loop
+        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
+            all_node_info, start_idx, end_idx
+        )
+
+        # find non chunk inputs
+        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
+
+        # reassgin reshape size, some size may have changed due to chunk
+        chunk_info = self._reassgin_reshape_size(chunk_info)
+
+        return chunk_info
+
+    def _reassgin_reshape_size(self, chunk_info):
+        chunk_region = chunk_info["region"]
+        reshape_size = {}
+        chunk_shape = get_node_shape(chunk_info["outputs"][0])[
+            chunk_info["outputs_dim"]
+        ]
+        for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]:
+            if any(i in node.name for i in ["reshape", "view"]):
+                reshape_args = node.args[1:]
+                reshape_log = self.trace_index.idx_view_list[node]
+                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
+                reshape_size[node.name] = {}
+                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
+                    if reshape_arg_dim in reshape_log["dim_to"]:
+                        continue
+                    if reshape_arg_dim == chunk_dim:
+                        reshape_size[node.name][reshape_arg.name] = (
+                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
+                        )
+        chunk_info["reshape_size"] = reshape_size
+        return chunk_info
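flow_search is the entry point of the new class: it packages a successful search as a plain dict. An illustration of the fields it fills, with keys matching the code above and values invented purely for readability:

# Illustrative only: the keys are the ones built in flow_search above; the values are made up.
example_chunk_info = {
    "region": (12, 47),                    # node indices bounding the chunk region
    "inputs": ["q", "k"],                  # chunked input nodes (fx nodes in reality)
    "inputs_non_chunk": [],                # inputs used whole, filled by _get_non_chunk_inputs
    "inputs_dim": [{15: [0]}, {16: [0]}],  # per input: {user node idx: source dims}
    "outputs": ["attn_out"],               # the single output node of the region
    "outputs_dim": 1,                      # dim along which the output is chunked
    "node_chunk_dim": {},                  # per-node {"chunk_dim": ..., "fix_dim": [...]}
    "args": {"prepose_nodes": []},         # nodes hoisted ahead of the generated loop
    "reshape_size": {},                    # reshape/view args rewritten by _reassgin_reshape_size
}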
@@ -1,12 +1,8 @@
 import copy

 from .utils import (
-    find_chunk_all_input_nodes,
-    find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
     get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
 )


@@ -588,394 +584,3 @@ class TraceIndex(object):
                 continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
-        # self._merge_equal_idx()
-
-    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
-        """
-        Check 2 given index: one index should be source of the other
-        Args:
-            start_idx(int): start node chunk dim
-            start_node(node): start node
-            end_idx(int): end node chunk dim
-            end_node(node): end node
-
-        Returns:
-            bool: True if check pass
-        """
-        start_node_idx = find_idx_by_name(start_node.name, self.node_list)
-        end_node_trace = self._find_trace_from_node(end_node)
-        end_node_trace_source = end_node_trace["source"][end_dim]
-        sorted_source = sorted(
-            end_node_trace_source.items(), key=lambda d: d[0], reverse=True
-        )
-        for node_idx, node_dim in sorted_source:
-            if node_idx == start_node_idx and start_dim in node_dim:
-                return True
-            # it means we meet a node outside the loop, and the node is not input node
-            if node_idx < start_idx:
-                return False
-        return False
-
-    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
-        """
-        Check 2 given index: check they haven't been computed in the source trace.
-        Args:
-            start_idx(int): start node chunk dim
-            start_node(node): start node
-            end_idx(int): end node chunk dim
-            end_node(node): end node
-
-        Returns:
-            bool: True if check pass
-        """
-        end_node_trace = self._find_trace_from_node(end_node)
-        end_node_compute = end_node_trace["compute"][end_dim]
-        if any(start_idx <= i <= end_idx for i in end_node_compute):
-            return False
-        return True
-
-    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
-        node_from_source = self._find_source_trace_from_node(node_from)
-        dim_source = node_from_source[node_from_dim]
-        node_to_idx = find_idx_by_name(node_to.name, self.node_list)
-        for k, v in dim_source.items():
-            if k == node_to_idx:
-                return v
-        return None
-
-    def _find_inherit_dim(self, input_node, input_dim, node):
-        input_node_idx = find_idx_by_name(input_node.name, self.node_list)
-        node_trace_source = self._find_source_trace_from_node(node)
-        for node_dim in range(len(get_node_shape(node))):
-            if (
-                input_node_idx in node_trace_source[node_dim]
-                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
-            ):
-                return node_dim
-        return None
-
-    def check_index_duplicate(self, chunk_infos, return_dim=False):
-        input_dim_after_node = {}
-        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
-            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
-                inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k])
-                if inherit_dim:
-                    input_dim_after_node[k] = inherit_dim
-
-        for node in self.node_list[
-            chunk_infos["region"][0] : chunk_infos["region"][1] + 1
-        ]:
-            if is_non_compute_node_except_placeholder(node):
-                continue
-            count = 0
-            duplicate_dims = []
-            node_trace_source = self._find_source_trace_from_node(node)
-            for node_dim in range(len(get_node_shape(node))):
-                duplicate_dim = []
-                duplicate_flag = False
-                dim_source = node_trace_source[node_dim]
-                for k, v in dim_source.items():
-                    if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
-                        if k in input_dim_after_node and input_dim_after_node[k] in v:
-                            duplicate_flag = True
-                            duplicate_dim.append((k, v))
-                duplicate_dims.append(duplicate_dim)
-                if duplicate_flag:
-                    count += 1
-
-            if count > 1:
-                if return_dim:
-                    return False, duplicate_dims
-                else:
-                    return False
-        if return_dim:
-            return True, None
-        else:
-            return True
-
-    def _assgin_single_node_flow(
-        self,
-        arg_node,
-        start_idx,
-        end_idx,
-        cur_node_dim,
-        cur_node_compute,
-        cur_node_source,
-        cur_node_fix_dim,
-        all_node_info,
-        next_node_list,
-    ):
-        arg_idx = find_idx_by_name(arg_node.name, self.node_list)
-        # arg in chunk range or be inputs
-        if not (start_idx <= arg_idx < end_idx):
-            return True
-
-        # find arg dim
-        if cur_node_dim is not None:
-            # dim is computed
-            if arg_idx in cur_node_compute[cur_node_dim]:
-                return False
-            if arg_idx not in cur_node_source[cur_node_dim]:
-                arg_dim = None
-            else:
-                arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
-        else:
-            arg_dim = None
-
-        # get fix dim
-        arg_fix_dim = []
-        if cur_node_dim is not None:
-            for i in cur_node_fix_dim:
-                fix_dim_source = cur_node_source[i]
-                if arg_idx in fix_dim_source:
-                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
-
-        # if already in node_info, arg dim must be same
-        if arg_node in all_node_info:
-            if all_node_info[arg_node]["chunk_dim"] != arg_dim:
-                return False
-            all_node_info[arg_node]["fix_dim"] = list(
-                set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
-            )
-        # else add it to list
-        else:
-            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
-
-        next_node_list.append(arg_node)
-        return True
-
-    def _get_all_node_info(self, end_dim, start_idx, end_idx):
-        cur_node_list = [self.node_list[end_idx]]  # start from the last node
-        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
-
-        while len(cur_node_list) > 0:
-            next_node_list = []
-
-            for cur_node in cur_node_list:
-                # get cur node info
-                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
-                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
-                if cur_node_chunk_dim:
-                    cur_node_compute = self._find_compute_trace_from_node(cur_node)
-                    cur_node_source = self._find_source_trace_from_node(cur_node)
-                else:
-                    cur_node_compute = cur_node_source = None
-
-                # get all valid args
-                arg_list = []
-                for arg in cur_node.args:
-                    if type(arg) != type(cur_node):
-                        continue
-                    if is_non_compute_node(arg):
-                        continue
-                    arg_list.append(arg)
-                    flow_flag = self._assgin_single_node_flow(
-                        arg,
-                        start_idx,
-                        end_idx,
-                        cur_node_chunk_dim,
-                        cur_node_compute,
-                        cur_node_source,
-                        cur_node_fix_dim,
-                        all_node_info,
-                        next_node_list,
-                    )
-                    if flow_flag == False:
-                        return None
-
-                if len(arg_list) == 2:
-                    if any(i in cur_node.name for i in ["add", "mul"]):
-                        for arg in arg_list:
-                            if not (
-                                start_idx
-                                <= find_idx_by_name(arg.name, self.node_list)
-                                < end_idx
-                            ):
-                                continue
-                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
-                            arg_fix_dim = all_node_info[arg]["fix_dim"]
-                            arg_shape = get_node_shape(arg)
-                            # add all dim as fix dim except chunk dim
-                            for i, shape in enumerate(arg_shape):
-                                if shape != 1 and i != cur_node_chunk_dim:
-                                    if i == arg_chunk_dim:
-                                        return None
-                                    if i not in arg_fix_dim:
-                                        arg_fix_dim.append(i)
-                    elif "einsum" in cur_node.name:
-                        pass
-                    elif "matmul" in cur_node.name:
-                        pass
-                    else:
-                        raise NotImplementedError()
-            cur_node_list = next_node_list
-        return all_node_info
-
-    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
-        inputs_dim = []
-        remove_inputs = []
-        for input_node in inputs:
-            input_dict = {}
-            input_node_idx = find_idx_by_name(input_node.name, self.node_list)
-            for user in input_node.users.keys():
-                if is_non_compute_node(user):
-                    continue
-                user_idx = find_idx_by_name(user.name, self.node_list)
-                if start_idx <= user_idx <= end_idx:
-                    chunk_dim = all_node_info[user]["chunk_dim"]
-                    if chunk_dim is not None:
-                        user_source = self._find_source_trace_from_node(user)[chunk_dim]
-                        if input_node_idx in user_source:
-                            input_dict[user_idx] = user_source[input_node_idx]
-                        else:
-                            return None, None
-            if len(input_dict) == 0:
-                remove_inputs.append(input_node)
-            else:
-                inputs_dim.append(input_dict)
-        for i in remove_inputs:
-            if i in inputs:
-                inputs.remove(i)
-        return inputs, inputs_dim
-
-    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
-        # get all possible prepose nodes
-        maybe_prepose_nodes = []
-        for node, node_info in all_node_info.items():
-            if node_info["chunk_dim"] is None:
-                maybe_prepose_nodes.append(node)
-        maybe_prepose_nodes.sort(
-            key=lambda x: find_idx_by_name(x.name, self.node_list),
-            reverse=True,
-        )  # from last node to first node
-        prepose_nodes = []
-        # set every node as root, search its args, if all legal, turn root and args as prepose nodes
-        while len(maybe_prepose_nodes) > 0:
-            tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
-            tmp_cur_related_prepose_nodes = []
-            prepose_flag = True
-
-            # loop cur node's all arg until out of chunk
-            while len(tmp_cur_prepose_nodes) > 0:
-                if prepose_flag == False:
-                    break
-                tmp_next_prepose_nodes = []
-                tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
-                for cur_prepose_node in tmp_cur_prepose_nodes:
-                    if prepose_flag == False:
-                        break
-                    for cur_prepose_node_arg in cur_prepose_node.args:
-                        if type(cur_prepose_node_arg) != type(cur_prepose_node):
-                            continue
-                        # out of loop
-                        if not (
-                            start_idx
-                            <= find_idx_by_name(
-                                cur_prepose_node_arg.name, self.node_list
-                            )
-                            < end_idx
-                        ):
-                            continue
-                        # compute op in loop
-                        elif cur_prepose_node_arg in all_node_info:
-                            if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
-                                tmp_next_prepose_nodes.append(cur_prepose_node_arg)
-                            else:
-                                prepose_flag = False
-                                break
-                        # non compute op
-                        else:
-                            tmp_next_prepose_nodes.append(cur_prepose_node_arg)
-                tmp_cur_prepose_nodes = tmp_next_prepose_nodes
-
-            if prepose_flag == False:
-                maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
-                continue
-            else:
-                for n in tmp_cur_related_prepose_nodes:
-                    if n not in prepose_nodes:
-                        prepose_nodes.append(n)
-                    if n in maybe_prepose_nodes:
-                        maybe_prepose_nodes.remove(n)
-        # sort by index
-        prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list))
-
-        return prepose_nodes
-
-    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
-        # we need to log input nodes to avoid deleteing them in the loop
-        chunk_node_list = self.node_list[start_idx : end_idx + 1]
-        # also need to get some prepose node's arg out of non_chunk_inputs
-        for n in chunk_info["args"]["prepose_nodes"]:
-            chunk_node_list.remove(n)
-        non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
-        for i in non_chunk_inputs:
-            if i not in chunk_info["inputs"]:
-                chunk_info["inputs_non_chunk"].append(i)
-        return chunk_info
-
-    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
-        inputs, outputs = find_chunk_compute_input_and_output_nodes(
-            self.node_list[start_idx : end_idx + 1]
-        )
-        # only single ouput
-        if len(outputs) > 1:
-            return None
-
-        # get every node's chunk dim and fix dim
-        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
-        if all_node_info is None:
-            return None
-
-        # get input nodes' chunk dim
-        inputs, inputs_dim = self._get_input_nodes_dim(
-            inputs, start_idx, end_idx, all_node_info
-        )
-        if inputs is None:
-            return None
-
-        chunk_info = {
-            "region": (start_idx, end_idx),
-            "inputs": inputs,
-            "inputs_non_chunk": [],
-            "inputs_dim": inputs_dim,
-            "outputs": outputs,
-            "outputs_dim": end_dim,
-            "node_chunk_dim": all_node_info,
-            "args": {},
-        }
-
-        # move useless nodes ahead of loop
-        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
-            all_node_info, start_idx, end_idx
-        )
-
-        # find non chunk inputs
-        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
-
-        # reassgin reshape size, some size may have changed due to chunk
-        chunk_info = self._reassgin_reshape_size(chunk_info)
-
-        return chunk_info
-
-    def _reassgin_reshape_size(self, chunk_info):
-        chunk_region = chunk_info["region"]
-        reshape_size = {}
-        chunk_shape = get_node_shape(chunk_info["outputs"][0])[
-            chunk_info["outputs_dim"]
-        ]
-        for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]:
-            if any(i in node.name for i in ["reshape", "view"]):
-                reshape_args = node.args[1:]
-                reshape_log = self.idx_view_list[node]
-                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
-                reshape_size[node.name] = {}
-                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
-                    if reshape_arg_dim in reshape_log["dim_to"]:
-                        continue
-                    if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = (
-                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
-                        )
-        chunk_info["reshape_size"] = reshape_size
-        return chunk_info
@@ -104,8 +104,8 @@ def benchmark_evoformer():
     model = evoformer_base().cuda()

     # build autochunk model
-    # max_memory = 1000  # MB fit memory mode
-    max_memory = None  # min memory mode
+    max_memory = 1000  # MB fit memory mode
+    # max_memory = None  # min memory mode
     autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)

     # build openfold
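The benchmark change only flips which memory mode is active. Presumably the two call patterns look like this, with `_build_autochunk` being the benchmark's own helper shown in the context lines above and the semantics taken from the inline comments:

# Fit-memory mode: pass a budget in MB, as the "MB fit memory mode" comment above suggests.
autochunk_fit = _build_autochunk(evoformer_base().cuda(), 1000, node, pair)

# Min-memory mode: max_memory=None, per the "min memory mode" comment above.
autochunk_min = _build_autochunk(evoformer_base().cuda(), None, node, pair)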