add possible region search

pull/2364/head
oahzxl 2022-12-04 17:05:28 +08:00
parent d9ca2f898d
commit 7330d90745
1 changed files with 109 additions and 7 deletions

View File

@ -356,7 +356,17 @@ class NodeIndexTracer(object):
"idx_to": [new_trace[i] for i in dim_to],
"dim_to": dim_to}
self.idx_view_list.append(view_dict)
def _merge_equal_idx(self):
    """Collapse indices recorded as equal onto a single representative.

    Walks ``self.idx_trace_equal`` from the last recorded group to the
    first. For each group the largest index is rewritten to the smallest
    one in the ``'idx'`` list of every entry of ``self.idx_trace_list``.
    The trace entries are updated in place; nothing is returned.

    Processing in reverse order lets chained equalities collapse fully:
    with groups ``(1, 2)`` then ``(2, 3)``, index 3 first becomes 2 and
    is then rewritten to 1.
    """
    # ``reversed`` yields the groups back-to-front without copying or
    # mutating ``self.idx_trace_equal`` (the groups are only read here).
    for equal_group in reversed(self.idx_trace_equal):
        merge_to = min(equal_group)
        merge_from = max(equal_group)
        for trace in self.idx_trace_list:
            if merge_from in trace['idx']:
                trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']]
def trace_node_idx(self):
for idx, node in enumerate(self.nodes_list):
if node.op == 'placeholder':
@ -396,6 +406,7 @@ class NodeIndexTracer(object):
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
self._merge_equal_idx()
class MemoryEstimator(object):
@ -433,6 +444,8 @@ class MemoryEstimator(object):
for i in range(len(out_node)):
if out_node[i][0] > 0:
delete_node.append(out_node[i][1][0])
elif nodes_to_delete[i].op == 'placeholder':
delete_node.append(nodes_to_delete[i].name)
return delete_size, delete_node
def _get_delete_node_size(self, user, user_to_last_uses):
@ -516,8 +529,9 @@ class MemoryEstimator(object):
active_node_list_log = []
not_contiguous_list = []
node_list = list(gm.graph.nodes)
user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
_delete_free_var_from_last_use(user_to_last_uses)
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
_delete_free_var_from_last_use(user_to_last_uses_no_free_var)
use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes])
chunk_within = False
@ -535,6 +549,7 @@ class MemoryEstimator(object):
if node.op == 'placeholder':
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
act_memory_peak_log.append(act_memory)
active_node_list.append(node.name)
# skip output
elif node.op == 'output':
continue
@ -549,10 +564,10 @@ class MemoryEstimator(object):
act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
if chunk_within:
act_memory -= self._get_chunk_delete_node_size(
node, user_to_last_uses, chunk_ratio, node_list,
node, user_to_last_uses_no_free_var, chunk_ratio, node_list,
start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2)
else:
act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var) / (1024 ** 2)
# log active node
self._add_active_node(node, active_node_list)
@ -572,8 +587,92 @@ class MemoryEstimator(object):
self._print_mem_log(act_memory_peak_log, node_list, "peak")
self._print_mem_log(act_memory_after_node_log, node_list, "after")
param_memory = parameter_size(gm)
return act_memory + param_memory, param_memory
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
class ChunkRegionSearch(object):
    """Search candidate chunk regions around the activation-memory peak.

    The search estimates per-node memory with ``MemoryEstimator``, picks
    the node with the highest peak, widens a window around it using the
    per-node active-node lists, and then enumerates (start, end) node
    pairs inside that window together with the dimensions whose traced
    index is shared and not marked as a compute index.
    """

    # Value returned by ``_get_min_free_var`` when every entry is a free
    # variable (or the list is empty). Kept for backward compatibility.
    _NO_ACTIVE_NODE_COUNT = 999

    def __init__(self, gm) -> None:
        """Store the graph module and trace the indices of its nodes.

        Args:
            gm: graph module; ``gm.graph.nodes`` is materialised into
                ``self.node_list``.
        """
        self.gm = gm
        self.node_list = list(gm.graph.nodes)
        self.memory_estimator = MemoryEstimator()
        self.index_tracer = NodeIndexTracer(gm)
        self.index_tracer.trace_node_idx()

    def _find_peak_node(self, mem_peak):
        """Return a one-element list holding the position of the maximum.

        When the maximum occurs several times, the first position wins.

        Raises:
            ValueError: if ``mem_peak`` is empty (raised by ``max``).
        """
        max_value = max(mem_peak)
        return [mem_peak.index(max_value)]

    def _get_free_var(self):
        """Return the positions in ``self.node_list`` of placeholder nodes."""
        return [idx for idx, n in enumerate(self.node_list) if n.op == 'placeholder']

    def _get_min_free_var(self, active_node_list, free_vars):
        """Return the smallest entry length among non-free positions.

        Args:
            active_node_list: sequence of sized entries, one per node
                position (presumably the active node names after each
                node -- confirm against ``MemoryEstimator``).
            free_vars: positions to ignore.

        Returns:
            The minimum ``len(entry)`` over positions not in
            ``free_vars``, or ``_NO_ACTIVE_NODE_COUNT`` when no position
            qualifies.
        """
        skipped = set(free_vars)
        # A true minimum, rather than a running minimum seeded with 999,
        # stays correct when every entry has 999 or more active nodes.
        return min(
            (len(n) for idx, n in enumerate(active_node_list) if idx not in skipped),
            default=self._NO_ACTIVE_NODE_COUNT,
        )

    def _search_max_chunk_region(self, active_node, peak_node):
        """Find the widest window around ``peak_node`` usable for chunking.

        Scans backwards and forwards from ``peak_node`` for the nearest
        position whose active-node count equals the minimum count
        computed by ``_get_min_free_var``.

        Args:
            active_node: per-position active-node entries.
            peak_node: position of the memory peak.

        Returns:
            ``(start, end)`` where ``start`` is one past the matching
            position found scanning backwards and ``end`` is one before
            the matching position found scanning forwards. ``end`` is
            ``None`` when the forward scan (which stops before the last
            position) finds no match.

        Raises:
            RuntimeError: if a scan reaches a free variable, or the
                backward scan reaches position 0, without a match.
        """
        free_vars = self._get_free_var()
        min_var = self._get_min_free_var(active_node, free_vars)

        # from peak_node down to the first position
        chunk_region_start = None
        for i in range(peak_node, -1, -1):
            if len(active_node[i]) == min_var:
                chunk_region_start = i + 1
                break
            if i in free_vars or i == 0:
                raise RuntimeError(
                    f"no chunk region start found before peak node {peak_node}")

        # from peak_node up to len - 2
        chunk_region_end = None
        for i in range(peak_node, len(active_node) - 1):
            if len(active_node[i]) == min_var:
                chunk_region_end = i - 1
                break
            # No ``i == 0`` check here: position 0 is only reachable when
            # peak_node is 0, and the backward scan above has then either
            # raised or proven that position 0 matches.
            if i in free_vars:
                raise RuntimeError(
                    f"no chunk region end found after peak node {peak_node}")
        return chunk_region_start, chunk_region_end

    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """Enumerate candidate regions that straddle ``peak_node``.

        Args:
            max_chunk_region: ``(start, end)`` window; candidates start in
                ``[start, peak_node)`` and end in ``[peak_node, end)``.
            peak_node: position of the memory peak.

        Returns:
            A list of dicts ``{'region': (before_idx, after_idx),
            'dim': [...]}`` where ``'dim'`` lists the leading dimensions
            whose traced index is identical at both ends and is not in
            either end's ``'compute'`` collection.
        """
        skipped_ops = ('placeholder', 'get_attr', 'output')
        skipped_name_parts = ('getitem', 'getattr')
        traces = self.index_tracer.idx_trace_list

        possible_chunk_region = []
        for before_idx in range(max_chunk_region[0], peak_node):
            for after_idx in range(peak_node, max_chunk_region[1]):
                before_node = self.node_list[before_idx]
                after_node = self.node_list[after_idx]
                # skip non compute nodes
                if before_node.op in skipped_ops or after_node.op in skipped_ops:
                    continue
                if any(part in name
                       for name in (before_node.name, after_node.name)
                       for part in skipped_name_parts):
                    continue

                # select free dim
                before_trace = traces[before_idx]
                after_trace = traces[after_idx]
                free_dim = [
                    dim
                    for dim, (before_i, after_i)
                    in enumerate(zip(before_trace['idx'], after_trace['idx']))
                    if before_i == after_i
                    and before_i not in before_trace['compute']
                    and after_i not in after_trace['compute']
                ]
                possible_chunk_region.append(
                    {'region': (before_idx, after_idx), 'dim': free_dim})
        return possible_chunk_region

    def search_region(self):
        """Run the search and return every candidate chunk region.

        Returns:
            The concatenation, over all peak nodes, of the lists produced
            by ``_search_possible_chunk_regions``. Existing callers that
            ignore the return value are unaffected.
        """
        mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
        possible_chunk_regions = []
        for peak_node in self._find_peak_node(mem_peak):
            max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
            possible_chunk_regions.extend(
                self._search_possible_chunk_regions(max_chunk_region, peak_node))
        return possible_chunk_regions
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
@ -696,6 +795,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
node_index_tracer = NodeIndexTracer(meta_graph)
node_index_tracer.trace_node_idx()
chunk_region_search = ChunkRegionSearch(meta_graph)
chunk_region_search.search_region()
# find the input and output var names for each chunk region
for idx, (start, end) in enumerate(chunk_regions):