
add chunk select

pull/2364/head
oahzxl 2 years ago
commit 8f5a0edfab
1 changed file: chunk_codegen.py (147 lines changed)
@@ -69,7 +69,7 @@ class IndexTracer(object):
         self.node_list = node_list
         self.idx_trace_list = self._init_idx_trace_list()
         self.idx_trace_equal = []
-        self.idx_view_list = []
+        self.idx_view_list = {}
         self.idx_count = -1
         self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))}
@@ -576,7 +576,7 @@ class IndexTracer(object):
             "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in dim_to],
             "dim_to": dim_to,
         }
-        self.idx_view_list.append(view_dict)
+        self.idx_view_list[node] = view_dict

     def _merge_equal_idx(self):
         idx_equal = copy.deepcopy(self.idx_trace_equal)
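
The two hunks above turn the view log from an append-only list into a dict keyed by the reshape/view node, so later passes can look a record up by node. A minimal standalone sketch of that access pattern (the key is a plain string here; in the real code it is the traced fx node, and only the keys visible in this diff are shown):

    # "view_1" stands in for the fx node used as the dict key.
    node = "view_1"
    view_dict = {
        "idx_to": [5],   # illustrative index ids only
        "dim_to": [1],
    }
    idx_view_list = {}          # was a list before this commit
    idx_view_list[node] = view_dict
    assert idx_view_list[node]["dim_to"] == [1]
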
@@ -702,7 +702,7 @@ class IndexTracer(object):
         for node_dim in range(len(_get_node_shape(node))):
             if (
                 input_node_idx in node_trace_source[node_dim]
-                and input_dim in node_trace_source[node_dim][input_node_idx]
+                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
             ):
                 return node_dim
         return None
@@ -875,6 +875,7 @@ class IndexTracer(object):
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
+            input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
             for user in input_node.users.keys():
                 if _is_non_compute_node(user):
                     continue
@@ -882,7 +883,11 @@ class IndexTracer(object):
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
                     if chunk_dim is not None:
-                        input_dict[user_idx] = chunk_dim
+                        user_source = self._find_source_trace_from_node(user)[chunk_dim]
+                        if input_node_idx in user_source:
+                            input_dict[user_idx] = user_source[input_node_idx]
+                        else:
+                            return None
             if len(input_dict) == 0:
                 remove_inputs.append(input_node)
             else:
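
In the hunk above, a chunk-region input no longer records the user's chunk dim directly; it records the input's own dim, recovered from the user's source trace, and the candidate region is rejected (return None) if the chunked dim cannot be traced back to that input. A standalone sketch of that lookup with made-up indices (not the project's API):

    def pick_input_dim(user_source, input_node_idx):
        # user_source is assumed to map producer-node index -> list of
        # producer dims that feed the user's chunked dim.
        if input_node_idx in user_source:
            return user_source[input_node_idx]
        return None  # mirrors the early "return None" that drops the region

    assert pick_input_dim({7: [0]}, 7) == [0]
    assert pick_input_dim({7: [0]}, 3) is None
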
@@ -898,6 +903,7 @@ class IndexTracer(object):
             "inputs_dim": inputs_dim,
             "outputs": outputs,
             "outputs_dim": end_dim,
+            "node_chunk_dim": all_node_info,
             "args": {},
         }
@@ -974,6 +980,26 @@ class IndexTracer(object):
             if i not in chunk_info["inputs"]:
                 chunk_info["inputs_non_chunk"].append(i)
+        # reassgin reshape size, some size may have changed due to chunk
+        chunk_info = self._reassgin_reshape_size(chunk_info)
         return chunk_info

+    def _reassgin_reshape_size(self, chunk_info):
+        chunk_region = chunk_info['region']
+        reshape_size = {}
+        for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]:
+            if any(i in node.name for i in ['reshape', 'view']):
+                reshape_args = node.args[1:]
+                reshape_log = self.idx_view_list[node]
+                chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim']
+                reshape_size[node.name] = {}
+                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
+                    if reshape_arg_dim in reshape_log['dim_to']:
+                        continue
+                    if reshape_arg_dim == chunk_dim:
+                        reshape_size[node.name][reshape_arg.name] = "chunk_size"
+        chunk_info['reshape_size'] = reshape_size
+        return chunk_info
+
     def _get_reorder_map(self, chunk_info):
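
The new _reassgin_reshape_size pass records, for each reshape/view node inside the chunk region, which size argument sits on the chunked dim so that codegen can later rewrite it to the literal "chunk_size". A standalone toy run of that scan (all names are invented; in the real code the entries are fx nodes and their .name attributes):

    chunk_dim = 1
    dim_to = [2]                                   # dims created by the view itself
    reshape_args = ["size_0", "size_1", "size_2"]  # stand-ins for fx size args

    reshape_size = {"view_3": {}}
    for arg_dim, arg_name in enumerate(reshape_args):
        if arg_dim in dim_to:       # produced by the view: value unaffected by chunking
            continue
        if arg_dim == chunk_dim:    # lies on the chunked dim: patch at codegen time
            reshape_size["view_3"][arg_name] = "chunk_size"

    assert reshape_size == {"view_3": {"size_1": "chunk_size"}}
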
@@ -1183,23 +1209,15 @@ class MemoryEstimator(object):
                 not_contiguous_list.append(node)
         return mem

-    def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size):
+    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
+        if node not in chunk_node_dim:
+            return 1.0
         node_shape = _get_node_shape(node)
-        node_source = self.index_tracer._find_source_trace_from_node(node)
-        for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
-            for k, v in input_node_dim.items():
-                # TODO: inherit dim should be list too, int now
-                inherit_dim = self.index_tracer._find_inherit_dim(
-                    input_node, v, self.index_tracer.node_list[k]
-                )
-                if k == _find_idx_by_name(node.name, self.index_tracer.node_list):
-                    chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
-                    return chunk_ratio
-                for dim, source in enumerate(node_source):
-                    if k in source and inherit_dim in source[k]:
-                        chunk_ratio = float(chunk_size) / node_shape[dim]
-                        return chunk_ratio
-        return 1.0
+        chunk_dim = chunk_node_dim[node]['chunk_dim']
+        if chunk_dim is None:
+            return 1.0
+        else:
+            return float(chunk_size) / node_shape[chunk_dim]

     def _get_chunk_delete_node_size(
         self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
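
With chunk dims now stored per node in the chunk info, the estimator's chunk ratio reduces to chunk_size over the size of the chunked dim. A minimal sketch with assumed shapes and numbers:

    def chunk_ratio(node_shape, chunk_dim, chunk_size):
        # A node chunked along chunk_dim only keeps chunk_size / shape[chunk_dim]
        # of its activation live at a time; unchunked nodes keep everything.
        if chunk_dim is None:
            return 1.0
        return float(chunk_size) / node_shape[chunk_dim]

    assert chunk_ratio((2, 4096, 64), 1, 128) == 128 / 4096
    assert chunk_ratio((2, 4096, 64), None, 128) == 1.0
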
@@ -1242,6 +1260,7 @@ class MemoryEstimator(object):
         self,
         node_list,
         chunk_infos=None,
+        print_mem=False,
     ):
         act_memory = 0.0
         act_memory_peak_log = []
@@ -1271,6 +1290,7 @@ class MemoryEstimator(object):
                 j.name for i in chunk_inputs_non_chunk for j in i
             ]
             chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]

         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
@@ -1285,8 +1305,7 @@ class MemoryEstimator(object):
             if chunk_within:
                 chunk_ratio = self._get_chunk_ratio(
                     node,
-                    chunk_inputs[chunk_region_idx],
-                    chunk_inputs_dim[chunk_region_idx],
+                    chunk_node_dim[chunk_region_idx],
                     chunk_size,
                 )
@@ -1357,11 +1376,12 @@ class MemoryEstimator(object):
             act_memory_after_node_log.append(act_memory)
             active_node_list_log.append(copy.deepcopy(active_node_list))

-        print("with chunk" if use_chunk else "without chunk")
-        # self._print_mem_log(act_memory_peak_log, node_list, "peak")
-        # self._print_mem_log(act_memory_after_node_log, node_list, "after")
-        self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
-        self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after")
+        if print_mem:
+            print("with chunk" if use_chunk else "without chunk")
+            # self._print_mem_log(act_memory_peak_log, node_list, "peak")
+            # self._print_mem_log(act_memory_after_node_log, node_list, "after")
+            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
+            self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after")

         # param_memory = parameter_size(gm)
         # all_memory = act_memory + param_memory
@@ -1369,21 +1389,70 @@ class MemoryEstimator(object):
 class ChunkSelector(object):
-    def __init__(self, index_tracer: IndexTracer, stratge) -> None:
+    def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge):
         self.index_tracer = index_tracer
+        self.memory_estimator = memory_estimator
         assert stratge in ['min_memory', 'fit_memory']
         self.stratge = stratge
-        self.max_memory = 800 # MB
+        self.max_memory = 600 # MB

-    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos):
+    def _select_best_chunk_region(self, possible_chunk_regions,
+                                  chunk_infos, peak_node, max_chunk_region, mem_peak):
         if self.stratge == 'min_memory':
             best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
         elif self.stratge == 'fit_memory':
-            pass
+            best_region = self._select_fit_memory_chunk_region(
+                possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak)
         else:
             raise RuntimeError()
         return best_region

+    def _select_fit_memory_chunk_region(self, possible_chunk_regions,
+                                        chunk_infos, peak_node, max_chunk_region, mem_peak):
+        # stop chunk if max memory satisfy memory limit
+        if max(mem_peak) < self.max_memory:
+            return None
+        # remove illegal regions
+        illegal_regions = []
+        for i in possible_chunk_regions:
+            if not self._is_legal_region(i, chunk_infos):
+                illegal_regions.append(i)
+        for i in illegal_regions:
+            if i in possible_chunk_regions:
+                possible_chunk_regions.remove(i)
+        # get mem for chunk region
+        regions_dict = []
+        for region in possible_chunk_regions:
+            cur_chunk_infos = chunk_infos + [region]
+            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
+                self.index_tracer.node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1]
+            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
+            if cur_chunk_region_max_peak < self.max_memory:
+                regions_dict.append({
+                    "chunk_info": region,
+                    "chunk_max_mem": cur_chunk_region_max_peak,
+                    "chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]),
+                })
+        # no region found
+        if len(regions_dict) == 0:
+            return None
+        # select the min chunk len
+        chunk_len = [i["chunk_len"] for i in regions_dict]
+        best_region_idx = chunk_len.index(min(chunk_len))
+        best_region = regions_dict[best_region_idx]["chunk_info"]
+        return best_region
+
+    def _get_compute_node_num(self, start, end):
+        count = 0
+        for i in self.index_tracer.node_list[start: end+1]:
+            if _is_non_compute_node(i):
+                count += 1
+        return count
+
     def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
         max_region_range = 0
         best_region = None
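
The new fit_memory strategy stops chunking once the simulated peak is under the budget, and among the candidate regions whose re-estimated peak fits, it picks the shortest one. A standalone toy version of that policy (the peak numbers are made up; this is not the project's API):

    MAX_MEMORY = 600  # MB, mirroring the budget set in __init__

    def select_fit_memory(candidates):
        # candidates: (region_name, simulated_peak_mb, region_len) triples
        fitting = [c for c in candidates if c[1] < MAX_MEMORY]
        if not fitting:
            return None                             # no region satisfies the budget
        return min(fitting, key=lambda c: c[2])[0]  # shortest fitting region wins

    print(select_fit_memory([("r0", 750, 3), ("r1", 580, 9), ("r2", 590, 4)]))
    # -> r2: fits under 600 MB and spans the fewest nodes
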
@@ -1421,7 +1490,7 @@ class ChunkRegionSearch(object):
         self.index_tracer = IndexTracer(list(gm.graph.nodes))
         self.index_tracer.trace_index()
         self.memory_estimator = MemoryEstimator(self.index_tracer)
-        self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory")
+        self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory")

     def _find_peak_node(self, mem_peak):
         max_value = max(mem_peak)
@@ -1575,7 +1644,7 @@ class ChunkRegionSearch(object):
             max_chunk_region, peak_node
         )
         best_chunk_region = self.chunk_selector._select_best_chunk_region(
-            possible_chunk_regions, chunk_regions
+            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
         )
         best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
         return best_chunk_region
@@ -1608,7 +1677,7 @@ class ChunkRegionSearch(object):
                 _,
                 active_node,
             ) = self.memory_estimator.estimate_chunk_inference_mem(
-                self.index_tracer.node_list, chunk_infos
+                self.index_tracer.node_list, chunk_infos, print_mem=True
             )
             if self._stop_search(init_mem_peak, mem_peak):
                 break
@@ -1736,6 +1805,13 @@ def _replace_name(context, name_from, name_to):
     return context


+def _replace_reshape_size(context, node_name, reshape_size_dict):
+    if node_name not in reshape_size_dict:
+        return context
+    for size_name, size_value in reshape_size_dict[node_name].items():
+        context = context.replace(size_name, size_value)
+    return context
+
+
 def emit_code_with_chunk(
     body,
     ckpt_func,
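
A quick usage sketch of the helper added above, applied to a hand-written line of generated code (the variable and node names are invented; real lines come from the fx code emitter):

    def _replace_reshape_size(context, node_name, reshape_size_dict):
        if node_name not in reshape_size_dict:
            return context
        for size_name, size_value in reshape_size_dict[node_name].items():
            context = context.replace(size_name, size_value)
        return context

    line = "view_3 = add_2.view(size_0, size_1, size_2)"
    sizes = {"view_3": {"size_1": "chunk_size"}}
    print(_replace_reshape_size(line, "view_3", sizes))
    # -> view_3 = add_2.view(size_0, chunk_size, size_2)
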
@@ -1802,11 +1878,12 @@ def emit_code_with_chunk(
                 for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
                     if idx == node_idx:
                         chunk_slice = _gen_chunk_slice_dim(
-                            dim, "chunk_idx", _get_node_shape(input_node)
+                            dim[0], "chunk_idx", _get_node_shape(input_node)
                         )
                         body[-1] = _replace_name(
                             body[-1], input_node.name, input_node.name + chunk_slice
                         )
+            body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size'])
             body[-1] = " " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
         else:
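
Two things change in the emitter above: the per-input dim is now a list (hence dim[0]), and reshape size arguments are rewritten after the name substitution. A hedged illustration of the kind of slice string substituted for a chunked input (the real _gen_chunk_slice_dim may format it differently):

    def gen_chunk_slice_dim_sketch(dim, chunk_idx_name, shape):
        # Build an indexing suffix that slices only `dim`, to be used inside a
        # generated loop like "for chunk_idx in range(0, N, chunk_size)".
        slices = [":"] * len(shape)
        slices[dim] = f"{chunk_idx_name}:{chunk_idx_name} + chunk_size"
        return "[" + ", ".join(slices) + "]"

    dims = [1]  # inputs_dim values are now lists; the emitter takes dims[0]
    print("hidden" + gen_chunk_slice_dim_sketch(dims[0], "chunk_idx", (2, 4096, 64)))
    # -> hidden[:, chunk_idx:chunk_idx + chunk_size, :]
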
