diff --git a/chunk_codegen.py b/chunk_codegen.py
index 01b29cb33..baf207795 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -85,25 +85,131 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
     act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
     act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
 
-    # for i in act_memory_peak_log:
-    #     print("%.2f " % i, end='')
-    # print("\n")
-    # for i in act_memory_after_node_log:
-    #     print("%.2f " % i, end='')
-    # print("\n")
+    print("no chunk")
+    _print_mem_log(act_memory_peak_log, "peak")
+    _print_mem_log(act_memory_after_node_log, "after")
 
     param_memory = parameter_size(gm)
     return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
 
 
-def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
-    node_size = 0
-    param_size = 0
-    for node in gm.graph.nodes:
-        node_size += calculate_fwd_tmp(node)
-        node_size += calculate_fwd_out(node)
-    param_size = parameter_size(gm)
-    return (node_size + param_size) / 1024**2, param_size / 1024**2
+def _get_chunk_ratio(node, chunk_dim, chunk_size):
+    """Return the fraction of ``node``'s ``chunk_dim`` axis covered by one chunk.
+
+    The axis length is read from ``node.meta['tensor_meta'].shape``; the result
+    is ``chunk_size`` divided by that length, as a float.
+    """
+    shape = node.meta['tensor_meta'].shape
+    return float(chunk_size) / shape[chunk_dim]
+
+
+def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
+    """Return the chunk-scaled size of the nodes listed for ``user`` in ``user_to_last_uses``.
+
+    Only nodes whose index (as reported by ``_find_idx_by_name``) lies in
+    ``[start_node, end_node)`` are counted, each as
+    ``_get_output_node_size(n) * chunk_ratio``. Placeholder and output users
+    contribute 0.
+    """
+    if user.op in ('placeholder', 'output'):
+        return 0
+    nodes_to_delete = user_to_last_uses.get(user, [])
+    delete_size = 0
+    for n in nodes_to_delete:
+        node_idx = _find_idx_by_name(n.name, node_list)
+        if start_node <= node_idx < end_node:
+            delete_size += _get_output_node_size(n) * chunk_ratio
+    return delete_size
+
+
+def _print_mem_log(log, title=None):
+    """Print the values of ``log`` on one line, each formatted as ``%.2f``.
+
+    When ``title`` is truthy it is printed first, left-justified to 8 columns.
+    """
+    if title:
+        print("%-8s" % title, end=' ')
+    # one write instead of one print call per entry; the output is identical
+    print("".join("%.2f " % i for i in log))
+
+
+def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
+    """Estimate activation memory of ``gm`` when some node regions run chunked.
+
+    Walks the graph nodes in order, scaling the per-node sizes inside each
+    region by that region's chunk ratio, and prints the peak and after-node
+    logs (each value divided by ``1024 ** 2``).
+
+    Args:
+        gm: graph module whose nodes are walked.
+        start_nodes: node index at which each chunk region starts.
+        end_nodes: node index at which each chunk region ends.
+        chunk_dims: chunked dimension of each region.
+        chunk_sizes: chunk size of each region.
+
+    Returns:
+        ``(activation + parameter memory, parameter memory)``, both divided
+        by ``1024 ** 2``.
+
+    NOTE(review): regions are consumed through a single ``region_idx``
+    counter, so this appears to assume the regions are sorted and do not
+    overlap -- confirm against the caller.
+    """
+    act_memory = 0
+    act_memory_peak_log = []
+    act_memory_after_node_log = []
+    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    within_chunk = False
+    region_idx = 0
+    chunk_ratio = 1    # use it to estimate chunk mem
+    node_list = list(gm.graph.nodes)
+
+    for idx, node in enumerate(node_list):
+        # if node in chunk start nodes, change chunk ratio and add chunk_tensor
+        if idx in start_nodes:
+            within_chunk = True
+            chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
+            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
+
+        # if node is placeholder, just add the size of the node
+        if node.op == 'placeholder':
+            act_memory += _get_meta_node_size(node) * chunk_ratio
+            act_memory_peak_log.append(act_memory)
+        # skip output
+        elif node.op == 'output':
+            continue
+        # node is an operation, calculate tmp, output node and delete node memory
+        else:
+            # forward memory
+            act_memory += calculate_fwd_tmp(node) * chunk_ratio
+            # act_memory += calculate_fwd_out(node)
+            act_memory += _get_output_node_size(node) * chunk_ratio
+            # record max act memory
+            act_memory_peak_log.append(act_memory)
+            # delete useless memory
+            act_memory -= calculate_fwd_tmp(node) * chunk_ratio
+            if within_chunk:
+                act_memory -= _get_chunk_delete_node_size(
+                    node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx])
+            else:
+                act_memory -= _get_delete_node_size(node, user_to_last_uses)
+
+        if idx in end_nodes:
+            act_memory -= _get_output_node_size(node) * chunk_ratio
+            within_chunk = False
+            chunk_ratio = 1
+            region_idx += 1
+
+        act_memory_after_node_log.append(act_memory)
+
+    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
+    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
+
+    print("chunk")
+    _print_mem_log(act_memory_peak_log, "peak")
+    _print_mem_log(act_memory_after_node_log, "after")
+
+    param_memory = parameter_size(gm)
+    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
 
 
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
@@ -444,7 +550,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     """
 
     # find the offload regions
-    chunk_regions = [(2, 5)]
+    chunk_regions = [(2, 6)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
@@ -452,6 +558,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False
 
     node_list = list(nodes)
+    _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
     _estimate_inference_mem(meta_graph)
 
     # find the input and output var names for each offload region