diff --git a/chunk_codegen.py b/chunk_codegen.py
index cdd0b1077..330f3dec6 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -69,7 +69,7 @@ class IndexTracer(object):
         self.node_list = node_list
         self.idx_trace_list = self._init_idx_trace_list()
         self.idx_trace_equal = []
-        self.idx_view_list = []
+        self.idx_view_list = {}
         self.idx_count = -1
         self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))}
 
@@ -576,7 +576,7 @@ class IndexTracer(object):
             "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in dim_to],
             "dim_to": dim_to,
         }
-        self.idx_view_list.append(view_dict)
+        self.idx_view_list[node] = view_dict
 
     def _merge_equal_idx(self):
         idx_equal = copy.deepcopy(self.idx_trace_equal)
@@ -702,7 +702,7 @@ class IndexTracer(object):
         for node_dim in range(len(_get_node_shape(node))):
             if (
                 input_node_idx in node_trace_source[node_dim]
-                and input_dim in node_trace_source[node_dim][input_node_idx]
+                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
             ):
                 return node_dim
         return None
@@ -875,6 +875,7 @@ class IndexTracer(object):
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
+            input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
            for user in input_node.users.keys():
                 if _is_non_compute_node(user):
                     continue
@@ -882,7 +883,11 @@
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
                     if chunk_dim is not None:
-                        input_dict[user_idx] = chunk_dim
+                        user_source = self._find_source_trace_from_node(user)[chunk_dim]
+                        if input_node_idx in user_source:
+                            input_dict[user_idx] = user_source[input_node_idx]
+                        else:
+                            return None
             if len(input_dict) == 0:
                 remove_inputs.append(input_node)
             else:
@@ -898,6 +903,7 @@
             "inputs_dim": inputs_dim,
             "outputs": outputs,
             "outputs_dim": end_dim,
+            "node_chunk_dim": all_node_info,
             "args": {},
         }
 
@@ -974,6 +980,26 @@ class IndexTracer(object):
             if i not in chunk_info["inputs"]:
                 chunk_info["inputs_non_chunk"].append(i)
 
+        # reassgin reshape size, some size may have changed due to chunk
+        chunk_info = self._reassgin_reshape_size(chunk_info)
+
+        return chunk_info
+
+    def _reassgin_reshape_size(self, chunk_info):
+        chunk_region = chunk_info['region']
+        reshape_size = {}
+        for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]:
+            if any(i in node.name for i in ['reshape', 'view']):
+                reshape_args = node.args[1:]
+                reshape_log = self.idx_view_list[node]
+                chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim']
+                reshape_size[node.name] = {}
+                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
+                    if reshape_arg_dim in reshape_log['dim_to']:
+                        continue
+                    if reshape_arg_dim == chunk_dim:
+                        reshape_size[node.name][reshape_arg.name] = "chunk_size"
+        chunk_info['reshape_size'] = reshape_size
         return chunk_info
 
     def _get_reorder_map(self, chunk_info):
@@ -1183,23 +1209,15 @@
             not_contiguous_list.append(node)
         return mem
 
-    def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size):
+    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
+        if node not in chunk_node_dim:
+            return 1.0
         node_shape = _get_node_shape(node)
-        node_source = self.index_tracer._find_source_trace_from_node(node)
-        for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
-            for k, v in input_node_dim.items():
-                # TODO: inherit dim should be list too, int now
-                inherit_dim = self.index_tracer._find_inherit_dim(
-                    input_node, v, self.index_tracer.node_list[k]
-                )
-                if k == _find_idx_by_name(node.name, self.index_tracer.node_list):
-                    chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
-                    return chunk_ratio
-                for dim, source in enumerate(node_source):
-                    if k in source and inherit_dim in source[k]:
-                        chunk_ratio = float(chunk_size) / node_shape[dim]
-                        return chunk_ratio
-        return 1.0
+        chunk_dim = chunk_node_dim[node]['chunk_dim']
+        if chunk_dim is None:
+            return 1.0
+        else:
+            return float(chunk_size) / node_shape[chunk_dim]
 
     def _get_chunk_delete_node_size(
         self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
@@ -1242,6 +1260,7 @@
         self,
         node_list,
         chunk_infos=None,
+        print_mem=False,
     ):
         act_memory = 0.0
         act_memory_peak_log = []
@@ -1271,6 +1290,7 @@
             j.name for i in chunk_inputs_non_chunk for j in i
         ]
         chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+        chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
 
         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
@@ -1285,8 +1305,7 @@
             if chunk_within:
                 chunk_ratio = self._get_chunk_ratio(
                     node,
-                    chunk_inputs[chunk_region_idx],
-                    chunk_inputs_dim[chunk_region_idx],
+                    chunk_node_dim[chunk_region_idx],
                     chunk_size,
                 )
 
@@ -1357,11 +1376,12 @@
             act_memory_after_node_log.append(act_memory)
             active_node_list_log.append(copy.deepcopy(active_node_list))
 
-        print("with chunk" if use_chunk else "without chunk")
-        # self._print_mem_log(act_memory_peak_log, node_list, "peak")
-        # self._print_mem_log(act_memory_after_node_log, node_list, "after")
-        self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
-        self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after")
+        if print_mem:
+            print("with chunk" if use_chunk else "without chunk")
+            # self._print_mem_log(act_memory_peak_log, node_list, "peak")
+            # self._print_mem_log(act_memory_after_node_log, node_list, "after")
+            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
+            self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after")
 
         # param_memory = parameter_size(gm)
         # all_memory = act_memory + param_memory
@@ -1369,21 +1389,70 @@
 
 
 class ChunkSelector(object):
-    def __init__(self, index_tracer: IndexTracer, stratge) -> None:
+    def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge):
         self.index_tracer = index_tracer
+        self.memory_estimator = memory_estimator
         assert stratge in ['min_memory', 'fit_memory']
         self.stratge = stratge
-        self.max_memory = 800  # MB
+        self.max_memory = 600  # MB
 
-    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos):
+    def _select_best_chunk_region(self, possible_chunk_regions,
+                                  chunk_infos, peak_node, max_chunk_region, mem_peak):
         if self.stratge == 'min_memory':
             best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
         elif self.stratge == 'fit_memory':
-            pass
+            best_region = self._select_fit_memory_chunk_region(
+                possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak)
         else:
             raise RuntimeError()
         return best_region
 
+    def _select_fit_memory_chunk_region(self, possible_chunk_regions,
+                                        chunk_infos, peak_node, max_chunk_region, mem_peak):
+        # stop chunk if max memory satisfy memory limit
+        if max(mem_peak) < self.max_memory:
+            return None
+
+        # remove illegal regions
+        illegal_regions = []
+        for i in possible_chunk_regions:
+            if not self._is_legal_region(i, chunk_infos):
+                illegal_regions.append(i)
+        for i in illegal_regions:
+            if i in possible_chunk_regions:
+                possible_chunk_regions.remove(i)
+
+        # get mem for chunk region
+        regions_dict = []
+        for region in possible_chunk_regions:
+            cur_chunk_infos = chunk_infos + [region]
+            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
+                self.index_tracer.node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1]
+            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
+            if cur_chunk_region_max_peak < self.max_memory:
+                regions_dict.append({
+                    "chunk_info": region,
+                    "chunk_max_mem": cur_chunk_region_max_peak,
+                    "chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]),
+                })
+        # no region found
+        if len(regions_dict) == 0:
+            return None
+
+        # select the min chunk len
+        chunk_len = [i["chunk_len"] for i in regions_dict]
+        best_region_idx = chunk_len.index(min(chunk_len))
+        best_region = regions_dict[best_region_idx]["chunk_info"]
+        return best_region
+
+    def _get_compute_node_num(self, start, end):
+        count = 0
+        for i in self.index_tracer.node_list[start: end + 1]:
+            if not _is_non_compute_node(i):
+                count += 1
+        return count
+
     def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
         max_region_range = 0
         best_region = None
@@ -1421,7 +1490,7 @@ class ChunkRegionSearch(object):
         self.index_tracer = IndexTracer(list(gm.graph.nodes))
         self.index_tracer.trace_index()
         self.memory_estimator = MemoryEstimator(self.index_tracer)
-        self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory")
+        self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory")
 
     def _find_peak_node(self, mem_peak):
         max_value = max(mem_peak)
@@ -1575,7 +1644,7 @@ class ChunkRegionSearch(object):
             max_chunk_region, peak_node
         )
         best_chunk_region = self.chunk_selector._select_best_chunk_region(
-            possible_chunk_regions, chunk_regions
+            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
         )
         best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
         return best_chunk_region
@@ -1608,7 +1677,7 @@ class ChunkRegionSearch(object):
                 _,
                 active_node,
             ) = self.memory_estimator.estimate_chunk_inference_mem(
-                self.index_tracer.node_list, chunk_infos
+                self.index_tracer.node_list, chunk_infos, print_mem=True
             )
             if self._stop_search(init_mem_peak, mem_peak):
                 break
@@ -1736,6 +1805,13 @@ def _replace_name(context, name_from, name_to):
     return context
 
 
+def _replace_reshape_size(context, node_name, reshape_size_dict):
+    if node_name not in reshape_size_dict:
+        return context
+    for size_name, size_value in reshape_size_dict[node_name].items():
+        context = context.replace(size_name, size_value)
+    return context
+
 def emit_code_with_chunk(
     body,
     ckpt_func,
@@ -1802,11 +1878,12 @@
                     for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
                         if idx == node_idx:
                             chunk_slice = _gen_chunk_slice_dim(
-                                dim, "chunk_idx", _get_node_shape(input_node)
+                                dim[0], "chunk_idx", _get_node_shape(input_node)
                             )
                             body[-1] = _replace_name(
                                 body[-1], input_node.name, input_node.name + chunk_slice
                             )
+            body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size'])
             body[-1] = "    " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
         else: