diff --git a/chunk_codegen.py b/chunk_codegen.py
index ade986d1e..77aca8deb 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -438,7 +438,7 @@ class MemoryEstimator(object):
     def _get_delete_node_size(self, user, user_to_last_uses):
         return self._get_delete_node(user, user_to_last_uses)[0]
 
-    def _remove_active_node(self, user, user_to_last_uses, active_list):
+    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
         delete_node = self._get_delete_node(user, user_to_last_uses)[1]
         for i in delete_node:
             active_list.remove(i)
@@ -481,48 +481,6 @@ class MemoryEstimator(object):
 
         return mem
 
-    def estimate_inference_mem(self, gm: torch.fx.GraphModule):
-        act_memory = 0.0
-        act_memory_peak_log = []
-        act_memory_after_node_log = []
-        not_contiguous_list = []
-        active_node_list = []
-        active_node_list_log = []
-        user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
-        _delete_free_var_from_last_use(user_to_last_uses)
-        for node in gm.graph.nodes:
-            # if node is placeholder, just add the size of the node
-            if node.op == 'placeholder':
-                act_memory += self._get_meta_node_size(node) / (1024 ** 2)
-                act_memory_peak_log.append(act_memory)
-                active_node_list.append(node.name)
-            # skip output
-            elif node.op == 'output':
-                continue
-            # node is an operation, calculate tmp, output node and delete node memory
-            else:
-                # forward memory
-                act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
-                act_memory += self._get_output_node_size(node) / (1024 ** 2)
-                # record max act memory
-                act_memory_peak_log.append(act_memory)
-                # delete useless memory
-                act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
-                act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
-                # log active node
-                self._add_active_node(node, active_node_list)
-                self._remove_active_node(node, user_to_last_uses, active_node_list)
-
-            act_memory_after_node_log.append(act_memory)
-            active_node_list_log.append(copy.deepcopy(active_node_list))
-        print("no chunk")
-        self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
-        self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
-
-        param_memory = parameter_size(gm)
-        return act_memory + param_memory, param_memory
-
-
     def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
         shape = node.meta['tensor_meta'].shape
         chunk_ratio = float(chunk_size) / shape[chunk_dim]
@@ -550,25 +508,28 @@ class MemoryEstimator(object):
             print("")
         print("\n")
 
-
-    def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
+    def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None):
         act_memory = 0.0
         act_memory_peak_log = []
         act_memory_after_node_log = []
+        active_node_list = []
+        active_node_list_log = []
         not_contiguous_list = []
+        node_list = list(gm.graph.nodes)
         user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
         _delete_free_var_from_last_use(user_to_last_uses)
-        within_chunk = False
-        region_idx = 0
+
+        use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes])
+        chunk_within = False
+        chunk_region_idx = 0
         chunk_ratio = 1    # use it to estimate chunk mem
-        node_list = list(gm.graph.nodes)
 
         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
-            if idx in start_nodes:
-                within_chunk = True
-                chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
-                act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
+            if use_chunk and idx in start_nodes:
+                chunk_within = True
+                chunk_ratio = self._get_chunk_ratio(node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx])
+                act_memory += self._get_output_node_size(node_list[end_nodes[chunk_region_idx]]) / (1024 ** 2)
 
             # if node is placeholder, just add the size of the node
             if node.op == 'placeholder':
@@ -586,22 +547,28 @@ class MemoryEstimator(object):
                 act_memory_peak_log.append(act_memory)
                 # delete useless memory
                 act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
-                if within_chunk:
+                if chunk_within:
                     act_memory -= self._get_chunk_delete_node_size(
                         node, user_to_last_uses, chunk_ratio, node_list,
-                        start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
+                        start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2)
                 else:
                     act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
-
-            if idx in end_nodes:
+
+            # log active node
+            self._add_active_node(node, active_node_list)
+            self._remove_deactive_node(node, user_to_last_uses, active_node_list)
+
+            # if node in chunk end nodes, restore chunk settings
+            if use_chunk and idx in end_nodes:
                 act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
-                within_chunk = False
+                chunk_within = False
                 chunk_ratio = 1
-                region_idx += 1
+                chunk_region_idx += 1
 
             act_memory_after_node_log.append(act_memory)
+            active_node_list_log.append(copy.deepcopy(active_node_list))
 
-        print("chunk")
+        print("with chunk" if use_chunk else "without chunk")
         self._print_mem_log(act_memory_peak_log, node_list, "peak")
         self._print_mem_log(act_memory_after_node_log, node_list, "after")
 
@@ -725,7 +692,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
 
     memory_estimator = MemoryEstimator()
     memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
-    memory_estimator.estimate_inference_mem(meta_graph)
+    memory_estimator.estimate_chunk_inference_mem(meta_graph)
     node_index_tracer = NodeIndexTracer(meta_graph)
     node_index_tracer.trace_node_idx()
 
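
Note: after this patch, estimate_chunk_inference_mem() covers both the chunked and the plain estimate (use_chunk is False when no chunk arguments are given), which is why the separate estimate_inference_mem() could be deleted. The sketch below shows how the unified entry point might be driven. It is illustrative only: the toy model, the ShapeProp-based setup of node.meta['tensor_meta'] (which _get_chunk_ratio and the size helpers read), the import path, and the region indices are all assumptions, not taken from this repository's actual pipeline, which passes a meta-traced meta_graph instead.

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace
    from torch.fx.passes.shape_prop import ShapeProp

    from chunk_codegen import MemoryEstimator    # assumed import path

    # Toy model standing in for the real meta graph.
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
    gm = symbolic_trace(model)
    # Populate node.meta['tensor_meta'] on every node; the estimator's
    # shape lookups assume this metadata is present.
    ShapeProp(gm).propagate(torch.randn(8, 64))

    estimator = MemoryEstimator()

    # No chunk config: use_chunk evaluates to False, reproducing the
    # no-chunk path of the removed estimate_inference_mem().
    estimator.estimate_chunk_inference_mem(gm)

    # One chunk region (node indices here are illustrative): nodes 1..3,
    # chunked along dim 1 in steps of 2, mirroring the [1], [2] call above.
    estimator.estimate_chunk_inference_mem(gm, start_nodes=[1], end_nodes=[3], chunk_dims=[1], chunk_sizes=[2])
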