basic chunk

pull/2364/head
oahzxl 2022-11-04 11:18:09 +08:00
parent f8aeecef46
commit c35718e8db
2 changed files with 95 additions and 45 deletions

View File

@@ -18,16 +18,61 @@ else:
__all__ = ['python_code_with_activation_checkpoint']
def _gen_loop_start(to_keep, chunk_size=2):
context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
new_shape = "["
for idx, i in enumerate(shape):
if idx == chunk_dim:
new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
else:
new_shape += ":"
new_shape += ", "
new_shape = new_shape[:-2] + "]"
return new_shape
def _get_first_non_single_dim(shape):
for idx, i in enumerate(shape):
if i == 1:
continue
else:
return idx
raise RuntimeError("can not get first non single dim for shape", shape)
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    """Emit the source lines that open a chunked for-loop region.

    Generates code that pre-allocates ``chunk_result`` with the output's
    shape, then loops over the input's first non-singleton dimension in
    steps of ``chunk_size``, slicing the input into ``chunk_tensor``.

    Args:
        chunk_input_meta: list with exactly one meta node; its
            ``meta['tensor_meta'].shape`` drives the loop
            (presumably a torch.fx node — TODO confirm).
        chunk_output: node whose ``meta['tensor_meta'].shape`` sizes the
            pre-allocated result buffer.
        chunk_size: step of the generated loop.

    Raises:
        NotImplementedError: if more than one chunk input is given.
    """
    if len(chunk_input_meta) != 1:
        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))

    input_node = chunk_input_meta[0]
    input_shape = input_node.meta['tensor_meta'].shape
    loop_dim = _get_first_non_single_dim(input_shape)
    slice_expr = _gen_chunk_slice_dim(loop_dim, "gen_chunk_idx", input_shape)
    result_shape = str(list(chunk_output.meta['tensor_meta'].shape))

    context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
        result_shape, input_node.name, input_node.name, chunk_size)
    context += "(0, %s.shape[%d], chunk_size):\n" % (input_node.name, loop_dim)
    context += " chunk_tensor = %s%s\n" % (input_node.name, slice_expr)
    return context
def _gen_loop_end(final_name, to_keep):
context = " chunk_result.append(" + final_name + ")\n"
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
context += final_name + " = chunk_result; chunk_result = None\n"
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    """Emit the source lines that close a chunked for-loop region.

    Generates code that writes the current chunk of the output into the
    slice of ``chunk_result``, then (outside the loop) rebinds the output
    name to ``chunk_result`` and clears the loop temporaries. If the chunk
    input has no user after the chunk output, it is also freed.

    Args:
        chunk_outputs: the single output node of the chunk region
            (presumably a torch.fx node — TODO confirm).
        chunk_inputs: list of input nodes; only the first entry is used.
        node_list: full ordered node list, used to compare use positions.
    """
    input_node = chunk_inputs[0]
    output_name = chunk_outputs.name
    output_idx = _find_idx_by_name(output_name, node_list)
    output_shape = chunk_outputs.meta['tensor_meta'].shape
    loop_dim = _get_first_non_single_dim(output_shape)
    slice_expr = _gen_chunk_slice_dim(loop_dim, "gen_chunk_idx", output_shape)

    context = " chunk_result%s = %s\n" % (slice_expr, output_name)
    context += output_name + " = chunk_result; chunk_result = None; chunk_size = None"

    # If every user of the chunk input appears no later than the chunk
    # output, this region is the input's last use, so free it here too.
    is_last_use = all(
        _find_idx_by_name(user.name, node_list) <= output_idx
        for user in input_node.users.keys())
    if is_last_use:
        context += "; %s = None" % input_node.name
    context += "\n"
    return context
@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for input_node in node._input_nodes.keys():
node_repr = repr(input_node)
if input_node not in nodes and node_repr not in input_nodes:
input_nodes.append(node_repr)
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
@@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for output_node in node.users.keys():
node_repr = repr(node)
if output_node not in nodes and node_repr not in output_nodes:
output_nodes.append(node_repr)
output_nodes.append(output_node)
return input_nodes, output_nodes
def _find_idx_by_name(name, nodes_list):
for idx, node in enumerate(nodes_list):
if node.name == name:
return idx
raise RuntimeError("name %s not found in node list" % name)
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
body.append(usage)
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
@@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
"""
# find the offload regions
chunk_regions = [(1, 4)]
chunk_regions = [(2, 5)]
chunk_starts = [item[0] for item in chunk_regions]
chunk_ends = [item[1] for item in chunk_regions]
chunk_inputs = []
@@ -319,48 +371,46 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
chunk_inputs.append(inputs)
chunk_outputs.append(outputs)
chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
chunk_inputs_names = []
for i in chunk_inputs:
for j in i:
chunk_inputs_names.append(j.name)
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
node_idx = 0
chunk_var = []
region_idx = 0
while node_idx < len(node_list):
# break if we finish the processing all the nodes
if node_idx >= len(node_list):
break
node = node_list[node_idx]
# process node in forward function
else:
node = node_list[node_idx]
if node_idx in chunk_starts:
within_chunk_region = True
# add for loop
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
if node_idx in chunk_starts:
within_chunk_region = True
body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
# save chunk input var, dont delete it
chunk_var.append(node.args[0].name)
# add for loop
body.append(_gen_loop_start(chunk_var[0]))
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
if node_idx in chunk_starts:
body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body, chunk_var)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_inputs_names)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_var)
if node_idx in chunk_ends:
body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
within_chunk_region = False
region_idx += 1
if node_idx in chunk_ends:
body.append(_gen_loop_end(node.name, chunk_var))
chunk_var = []
within_chunk_region = False
node_idx += 1
node_idx += 1
if CODEGEN_AVAILABLE:
@@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values)
emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body

View File

@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
# setattr(node, "activation_offload", [0, True, False])
codegen = ChunkCodeGen(gm_prop)
# graph.set_codegen(codegen)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()