format code

pull/2364/head
oahzxl 2022-12-21 15:03:08 +08:00
parent d361d533e8
commit ded1005667
1 changed file with 122 additions and 62 deletions


@@ -144,7 +144,9 @@ class IndexTracer(object):
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
else:
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim)
node_to_trace["source"][node_to_dim][node_from_idx].append(
node_from_dim
)
# update inputs source
node_to_trace["source"][node_to_dim].update(
node_from_trace["source"][node_from_dim]
@@ -745,7 +747,6 @@ class IndexTracer(object):
return True
class FlowTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
@@ -856,7 +857,9 @@ class FlowTracer(object):
)
return self.flow_trace
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
def _detect_flow(
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
):
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
@@ -945,8 +948,10 @@ class FlowTracer(object):
for i in remove_inputs:
if i in chunk_info["inputs"]:
chunk_info["inputs"].remove(i)
duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True)
duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(
chunk_info, return_dim=True
)
# we need to log input nodes to avoid deleting them in the loop
non_chunk_inputs = _find_chunk_all_input_nodes(
@@ -958,15 +963,25 @@ class FlowTracer(object):
return flow_block, chunk_info
def _assgin_single_node_flow(self, arg_node, start_idx, end_idx,
inputs, index_tracer, cur_node_dim,
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
next_node_list):
def _assgin_single_node_flow(
self,
arg_node,
start_idx,
end_idx,
inputs,
index_tracer,
cur_node_dim,
cur_node_compute,
cur_node_source,
cur_node_fix_dim,
all_node_info,
next_node_list,
):
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
# arg is in chunk range or is an input
if not (start_idx <= arg_idx < end_idx):
return True
# find arg dim
if cur_node_dim is not None:
# dim is computed
@@ -978,7 +993,7 @@ class FlowTracer(object):
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
else:
arg_dim = None
# get fix dim
arg_fix_dim = []
if cur_node_dim is not None:
@@ -986,44 +1001,52 @@ class FlowTracer(object):
fix_dim_source = cur_node_source[i]
if arg_idx in fix_dim_source:
arg_fix_dim.append(fix_dim_source[arg_idx][0])
# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node] != arg_dim:
return False
all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim))
all_node_info[arg_node]["fix_dim"] = list(
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
)
# else add it to list
else:
all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim}
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
next_node_list.append(arg_node)
return True
def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
def flow_search(
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
):
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
# only single output
if len(outputs) > 1:
return None
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}}
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
next_node_list = []
for cur_node in cur_node_list:
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim']
cur_node_fix_dim = all_node_info[cur_node]['fix_dim']
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
if cur_node_chunk_dim:
cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node)
cur_node_source = index_tracer._find_source_trace_from_node(cur_node)
cur_node_compute = index_tracer._find_compute_trace_from_node(
cur_node
)
cur_node_source = index_tracer._find_source_trace_from_node(
cur_node
)
else:
cur_node_compute = cur_node_source = None
# get all valid args
arg_list = []
for arg in cur_node.args:
@@ -1032,20 +1055,33 @@ class FlowTracer(object):
if _is_non_compute_node(arg):
continue
arg_list.append(arg)
flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx,
inputs, index_tracer, cur_node_chunk_dim,
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
next_node_list)
flow_flag = self._assgin_single_node_flow(
arg,
start_idx,
end_idx,
inputs,
index_tracer,
cur_node_chunk_dim,
cur_node_compute,
cur_node_source,
cur_node_fix_dim,
all_node_info,
next_node_list,
)
if flow_flag == False:
return None
if len(arg_list) == 2:
if any(i in cur_node.name for i in ["add", "mul"]):
for arg in arg_list:
if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx):
if not (
start_idx
<= _find_idx_by_name(arg.name, index_tracer.nodes_list)
< end_idx
):
continue
arg_chunk_dim = all_node_info[arg]['chunk_dim']
arg_fix_dim = all_node_info[arg]['fix_dim']
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
arg_shape = _get_node_shape(arg)
# add all dim as fix dim except chunk dim
for i, shape in enumerate(arg_shape):
@@ -1061,7 +1097,7 @@ class FlowTracer(object):
else:
raise NotImplementedError()
cur_node_list = next_node_list
inputs_dim = []
remove_inputs = []
for input_node in inputs:
@@ -1071,7 +1107,7 @@ class FlowTracer(object):
continue
user_idx = _find_idx_by_name(user.name, self.node_list)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]['chunk_dim']
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
input_dict[user_idx] = chunk_dim
if len(input_dict) == 0:
@@ -1081,7 +1117,7 @@ class FlowTracer(object):
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
chunk_info = {
"region": (start_idx, end_idx),
"inputs": inputs,
@@ -1091,7 +1127,7 @@ class FlowTracer(object):
"outputs_dim": end_dim,
"args": {},
}
# we need to log input nodes to avoid deleting them in the loop
non_chunk_inputs = _find_chunk_all_input_nodes(
self.node_list[start_idx : end_idx + 1]
@@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
if n.op == 'placeholder':
if n.op == "placeholder":
new_active.append(n.name)
for i in new_active:
if i not in active_list:
@@ -1168,12 +1204,16 @@ class MemoryEstimator(object):
for i in delete_node:
if i in active_list:
active_list.remove(i)
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
def _get_chunk_inputs_size(
self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users]
chunk_input_users_idx = [
_find_idx_by_name(i.name, node_list) for i in chunk_input_users
]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
@@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
for k, v in input_node_dim.items():
# TODO: inherit dim should be list too, int now
inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k])
inherit_dim = self.index_tracer._find_inherit_dim(
input_node, v, self.index_tracer.nodes_list[k]
)
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
return chunk_ratio
@@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
if k in source and inherit_dim in source[k]:
chunk_ratio = float(chunk_size) / node_shape[dim]
return chunk_ratio
return 1.
return 1.0
def _get_chunk_delete_node_size(
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
@@ -1295,7 +1337,7 @@ class MemoryEstimator(object):
chunk_ratio = 1 # use it to estimate chunk mem
chunk_size = 1
chunk_inputs_names = []
if use_chunk:
chunk_regions = [i["region"] for i in chunk_infos]
chunk_starts = [i[0] for i in chunk_regions]
@@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
act_memory += self._get_output_node_size(
chunk_outputs[chunk_region_idx]
) / (1024**2)
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(
node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size
node,
chunk_inputs[chunk_region_idx],
chunk_inputs_dim[chunk_region_idx],
chunk_size,
)
# if node is placeholder, just add the size of the node
@@ -1353,18 +1400,18 @@ class MemoryEstimator(object):
/ (1024**2)
)
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
# we can't delete input nodes until chunk ends
if chunk_within:
act_memory -= self._get_chunk_delete_node_size(
node,
user_to_last_uses_no_free_var,
chunk_ratio,
chunk_inputs_names
chunk_inputs_names,
) / (1024**2)
else:
act_memory -= (self._get_delete_node_size(
act_memory -= self._get_delete_node_size(
node, user_to_last_uses_no_free_var, chunk_inputs_names
) / (1024**2))
) / (1024**2)
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
@@ -1376,11 +1423,11 @@ class MemoryEstimator(object):
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
act_memory -= self._get_chunk_inputs_size(
chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx],
chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx],
node_list,
chunk_regions[chunk_region_idx][1]
) / (1024**2)
chunk_regions[chunk_region_idx][1],
) / (1024**2)
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
@@ -1436,7 +1483,7 @@ class ChunkRegionSearch(object):
active_node_num = [len(i) for i in active_node]
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
@@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
continue
for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2:
if (
start_idx == 199
and end_idx == 229
and start_dim == 2
and end_dim == 2
):
print(1)
self.flow_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim, self.index_tracer
@@ -1576,7 +1628,7 @@ class ChunkRegionSearch(object):
max_region_range = 0
best_region = None
return best_region
def _is_legal_region(self, cur_chunk_info, chunk_infos):
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
if cur_chunk_info in chunk_infos:
@@ -1585,11 +1637,13 @@ class ChunkRegionSearch(object):
return False
for i in chunk_infos:
region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])):
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
def _step_search(self, mem_peak, active_node, chunk_regions):
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(
@@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions)
best_chunk_region = self._search_best_chunk_region(
possible_chunk_regions, chunk_regions
)
return best_chunk_region
def _stop_search(self, init_mem_peak, mem_peak):
@@ -1667,7 +1723,11 @@ def _gen_loop_end(
chunk_slice = _gen_chunk_slice_dim(
chunk_outputs_dim, "chunk_idx", chunk_output_shape
)
context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name)
context = " chunk_result%s = %s; %s = None\n" % (
chunk_slice,
chunk_outputs_name,
chunk_outputs_name,
)
context += (
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
)