diff --git a/chunk_codegen.py b/chunk_codegen.py
index 1f336eb2b..1267f64cb 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -18,16 +18,61 @@ else:
     __all__ = ['python_code_with_activation_checkpoint']


-def _gen_loop_start(to_keep, chunk_size=2):
-    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
-    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
+def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
+    new_shape = "["
+    for idx, i in enumerate(shape):
+        if idx == chunk_dim:
+            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
+        else:
+            new_shape += ":"
+        new_shape += ", "
+    new_shape = new_shape[:-2] + "]"
+    return new_shape
+
+
+def _get_first_non_single_dim(shape):
+    for idx, i in enumerate(shape):
+        if i == 1:
+            continue
+        else:
+            return idx
+    raise RuntimeError("can not get first non single dim for shape", shape)
+
+
+def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
+    if len(chunk_input_meta) == 1:
+        node = chunk_input_meta[0]
+        node_shape = node.meta['tensor_meta'].shape
+        chunk_dim = _get_first_non_single_dim(node_shape)
+        chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
+        out_shape = str(list(chunk_output.meta['tensor_meta'].shape))
+
+        context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
+            out_shape, node.name, node.name, chunk_size)
+        context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
+        context += "    chunk_tensor = %s%s\n" % (node.name, chunk_slice)
+    else:
+        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))
     return context


-def _gen_loop_end(final_name, to_keep):
-    context = "    chunk_result.append(" + final_name + ")\n"
-    context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
-    context += final_name + " = chunk_result; chunk_result = None\n"
+def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
+    chunk_inputs_name = chunk_inputs[0].name
+    chunk_outputs_name = chunk_outputs.name
+    chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
+    chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape
+    chunk_dim = _get_first_non_single_dim(chunk_output_shape)
+    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
+    context = "    chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
+
+    context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
+
+    # determine if it's the last use of the chunk input
+    users = list(chunk_inputs[0].users.keys())
+    if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users]):
+        context += "; %s = None" % chunk_inputs_name
+
+    context += "\n"
     return context


@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
         for input_node in node._input_nodes.keys():
             node_repr = repr(input_node)
             if input_node not in nodes and node_repr not in input_nodes:
-                input_nodes.append(node_repr)
+                input_nodes.append(input_node)

     # if a node has a user node which is not in the node list
     # we treat that user node as the node receiving the current node output
@@ -52,11 +97,18 @@
         for output_node in node.users.keys():
             node_repr = repr(node)
             if output_node not in nodes and node_repr not in output_nodes:
-                output_nodes.append(node_repr)
+                output_nodes.append(output_node)

     return input_nodes, output_nodes


+def _find_idx_by_name(name, nodes_list):
+    for idx, node in enumerate(nodes_list):
+        if node.name == name:
+            return idx
+    raise RuntimeError("name %s not found in node list" % name)
+
+
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
     In pofo algorithm, during annotation, we will annotate the offload region with the
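The three new helpers are pure string builders: _get_first_non_single_dim picks the dimension to chunk over, _gen_chunk_slice_dim renders the slice expression for that dimension, and _gen_loop_start/_gen_loop_end wrap the region in a loop that fills a preallocated chunk_result instead of torch.cat-ing a Python list. A minimal sketch of what they produce, assuming a hypothetical input node whose meta shape is [1, 16, 32]:

    # Sketch only: exercises the string builders above with a made-up shape.
    node_shape = [1, 16, 32]
    chunk_dim = _get_first_non_single_dim(node_shape)    # -> 1, since dim 0 has size 1
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
    print(chunk_slice)
    # prints: [:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]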
@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
         body.append(usage)


-def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
+def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use this function
     to emit the activation checkpoint codes.
@@ -304,7 +356,7 @@
     """
     # find the offload regions
-    chunk_regions = [(1, 4)]
+    chunk_regions = [(2, 5)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
     chunk_outputs = []
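With chunk_regions hard-coded to [(2, 5)], graph nodes 2 through 5 are emitted inside the generated loop. Composing _gen_loop_start and _gen_loop_end, the emitted forward() section for one region would look roughly like the sketch below (hypothetical node names linear1 / relu_1 / add_1 with shape [1, 16, 32]; real names come from the traced graph):

    chunk_result = torch.empty([1, 16, 32], dtype=linear1.dtype, device=linear1.device); chunk_size = 2
    for gen_chunk_idx in range(0, linear1.shape[1], chunk_size):
        chunk_tensor = linear1[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]
        relu_1 = torch.relu(chunk_tensor)
        add_1 = relu_1 + 1
        chunk_result[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :] = add_1
    add_1 = chunk_result; chunk_result = None; chunk_size = None; linear1 = None

The trailing linear1 = None is only emitted when _gen_loop_end finds no user of the region input after the region output.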
@@ -319,48 +371,46 @@
         inputs, outputs = _find_input_and_output_nodes(offload_node_list)
         chunk_inputs.append(inputs)
         chunk_outputs.append(outputs)
-
+    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
+    chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
+    chunk_inputs_names = []
+    for i in chunk_inputs:
+        for j in i:
+            chunk_inputs_names.append(j.name)
+
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    chunk_var = []
+    region_idx = 0
     while node_idx < len(node_list):
-        # break if we finish the processing all the nodes
-        if node_idx >= len(node_list):
-            break
+        node = node_list[node_idx]

-        # process node in forward function
-        else:
-            node = node_list[node_idx]
+        if node_idx in chunk_starts:
+            within_chunk_region = True
+
+            # add for loop
+            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
+            body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))

+        if within_chunk_region:
+            emit_node_func(node, body)
+            # replace input var with chunk var
             if node_idx in chunk_starts:
-                within_chunk_region = True
+                body[-1] = body[-1].replace("(" + chunk_inputs[region_idx][0].name + ")", '(chunk_tensor)')
+            body[-1] = '    ' + body[-1]
+            delete_unused_value_func(node, body, chunk_inputs_names)

-                # save chunk input var, dont delete it
-                chunk_var.append(node.args[0].name)
-
-                # add for loop
-                body.append(_gen_loop_start(chunk_var[0]))
-
-            if within_chunk_region:
-                emit_node_func(node, body)
-                # replace input var with chunk var
-                if node_idx in chunk_starts:
-                    body[-1] = body[-1].replace("(" + chunk_var[0] + ")", '(chunk_tensor)')
-                body[-1] = '    ' + body[-1]
-                delete_unused_value_func(node, body, chunk_var)
+        else:
+            emit_node_func(node, body)
+            if all(node_idx not in i for i in chunk_inputs_idx):
+                delete_unused_value_func(node, body, chunk_inputs_names)

-            else:
-                emit_node_func(node, body)
-                if node_idx not in chunk_inputs:
-                    delete_unused_value_func(node, body, chunk_var)
+        if node_idx in chunk_ends:
+            body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
+            within_chunk_region = False
+            region_idx += 1

-            if node_idx in chunk_ends:
-                body.append(_gen_loop_end(node.name, chunk_var))
-                chunk_var = []
-                within_chunk_region = False
-
-            node_idx += 1
+        node_idx += 1


 if CODEGEN_AVAILABLE:
@@ -562,7 +612,7 @@

         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index b875b6308..547b983a9 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
     #     setattr(node, "activation_offload", [0, True, False])

     codegen = ChunkCodeGen(gm_prop)
-    # graph.set_codegen(codegen)
+    graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()
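Whether the generated epilogue frees the region input is decided by the last-use test in _gen_loop_end. A self-contained toy of the same test (FakeNode is a hypothetical stand-in for torch.fx.Node; _find_idx_by_name is the helper added in the patch):

    # Free the chunk input only if every one of its users sits at or before
    # the region output in the topologically ordered node list.
    class FakeNode:
        def __init__(self, name, users=()):
            self.name = name
            self.users = {u: None for u in users}

    def _find_idx_by_name(name, nodes_list):
        for idx, node in enumerate(nodes_list):
            if node.name == name:
                return idx
        raise RuntimeError("name %s not found in node list" % name)

    out = FakeNode("add_1")
    inp = FakeNode("linear1", users=[out])
    node_list = [inp, FakeNode("relu_1"), out]

    out_idx = _find_idx_by_name(out.name, node_list)
    last_use = all(_find_idx_by_name(u.name, node_list) <= out_idx for u in inp.users)
    assert last_use    # so the generated epilogue appends "; linear1 = None"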