basic chunk

pull/2364/head
oahzxl 2022-11-04 11:18:09 +08:00
parent f8aeecef46
commit c35718e8db
2 changed files with 95 additions and 45 deletions

View File

@@ -18,16 +18,61 @@ else:
     __all__ = ['python_code_with_activation_checkpoint']
 
-def _gen_loop_start(to_keep, chunk_size=2):
-    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
-    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
+def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
+    new_shape = "["
+    for idx, i in enumerate(shape):
+        if idx == chunk_dim:
+            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
+        else:
+            new_shape += ":"
+        new_shape += ", "
+    new_shape = new_shape[:-2] + "]"
+    return new_shape
+
+
+def _get_first_non_single_dim(shape):
+    for idx, i in enumerate(shape):
+        if i == 1:
+            continue
+        else:
+            return idx
+    raise RuntimeError("can not get first non single dim for shape", shape)
+
+
+def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
+    if len(chunk_input_meta) == 1:
+        node = chunk_input_meta[0]
+        node_shape = node.meta['tensor_meta'].shape
+        chunk_dim = _get_first_non_single_dim(node_shape)
+        chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
+        out_shape = str(list(chunk_output.meta['tensor_meta'].shape))
+        context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
+            out_shape, node.name, node.name, chunk_size)
+        context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
+        context += "    chunk_tensor = %s%s\n" % (node.name, chunk_slice)
+    else:
+        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))
     return context
 
-def _gen_loop_end(final_name, to_keep):
-    context = "    chunk_result.append(" + final_name + ")\n"
-    context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
-    context += final_name + " = chunk_result; chunk_result = None\n"
+def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
+    chunk_inputs_name = chunk_inputs[0].name
+    chunk_outputs_name = chunk_outputs.name
+    chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
+    chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape
+    chunk_dim = _get_first_non_single_dim(chunk_output_shape)
+    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
+    context = "    chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
+    context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
+
+    # determine if its the last use for chunk input
+    users_name = list(chunk_inputs[0].users.keys())
+    if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users_name]):
+        context += "; %s = None" % chunk_inputs_name
+
+    context += "\n"
     return context
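
To make the emitted strings concrete: for a hypothetical input node named linear_1 with meta shape [1, 16, 64], _get_first_non_single_dim returns 1 (dim 0 has size 1), and _gen_chunk_slice_dim(1, "gen_chunk_idx", [1, 16, 64]) produces the slice string "[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]". Assuming the region output has the same shape, the loop prologue generated by _gen_loop_start would read:

    # illustrative output of _gen_loop_start; node name and shape are assumptions
    chunk_result = torch.empty([1, 16, 64], dtype=linear_1.dtype, device=linear_1.device); chunk_size = 2
    for gen_chunk_idx in range(0, linear_1.shape[1], chunk_size):
        chunk_tensor = linear_1[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]

_gen_loop_end then writes each iteration's result into the matching slice of chunk_result, swaps the output name over to chunk_result, and also frees the chunk input when the region output is its last user.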
@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
         for input_node in node._input_nodes.keys():
             node_repr = repr(input_node)
             if input_node not in nodes and node_repr not in input_nodes:
-                input_nodes.append(node_repr)
+                input_nodes.append(input_node)
 
     # if a node has a user node which is not in the node list
     # we treat that user node as the node receiving the current node output
@@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
         for output_node in node.users.keys():
             node_repr = repr(node)
             if output_node not in nodes and node_repr not in output_nodes:
-                output_nodes.append(node_repr)
+                output_nodes.append(output_node)
 
     return input_nodes, output_nodes
 
+def _find_idx_by_name(name, nodes_list):
+    for idx, node in enumerate(nodes_list):
+        if node.name == name:
+            return idx
+    raise RuntimeError("name %s not found in node list" % name)
+
+
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
 
     In pofo algorithm, during annotation, we will annotate the offload region with the
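
Two details worth calling out: _find_input_and_output_nodes now collects the Node objects instead of their repr strings, which the chunk codegen depends on (it reads .name, .meta and .users from the stored inputs and outputs), and _find_idx_by_name maps a node name back to its position in the traced node list. A minimal sketch of the latter, using a stand-in node class (FakeNode is purely illustrative):

    from dataclasses import dataclass

    @dataclass
    class FakeNode:    # stand-in for torch.fx.Node; only .name matters here
        name: str

    node_list = [FakeNode("linear_1"), FakeNode("relu_1"), FakeNode("add_1")]
    assert _find_idx_by_name("relu_1", node_list) == 1
    # an unknown name raises RuntimeError("name %s not found in node list" % name)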
@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
         body.append(usage)
 
-def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
+def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
     this function to emit the activation checkpoint codes.
@@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     """
 
     # find the offload regions
-    chunk_regions = [(1, 4)]
+    chunk_regions = [(2, 5)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
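
chunk_regions is still hardcoded at this stage: each (start, end) pair holds node indices into the traced graph, inclusive on both ends, so (2, 5) chunks nodes 2 through 5. The slicing itself sits outside this hunk, but judging from the loop below it presumably looks like:

    # assumed region extraction, consistent with the loop body that follows
    for start, end in chunk_regions:
        offload_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)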
@@ -319,46 +371,44 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         inputs, outputs = _find_input_and_output_nodes(offload_node_list)
         chunk_inputs.append(inputs)
         chunk_outputs.append(outputs)
+    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
+    chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
+    chunk_inputs_names = []
+    for i in chunk_inputs:
+        for j in i:
+            chunk_inputs_names.append(j.name)
 
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    chunk_var = []
+    region_idx = 0
     while node_idx < len(node_list):
-        # break if we finish the processing all the nodes
-        if node_idx >= len(node_list):
-            break
-
-        # process node in forward function
-        else:
-            node = node_list[node_idx]
+        node = node_list[node_idx]
 
         if node_idx in chunk_starts:
             within_chunk_region = True
-            # save chunk input var, dont delete it
-            chunk_var.append(node.args[0].name)
             # add for loop
-            body.append(_gen_loop_start(chunk_var[0]))
+            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
+            body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))
 
         if within_chunk_region:
             emit_node_func(node, body)
             # replace input var with chunk var
             if node_idx in chunk_starts:
-                body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
+                body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
             body[-1] = '    ' + body[-1]
-            delete_unused_value_func(node, body, chunk_var)
+            delete_unused_value_func(node, body, chunk_inputs_names)
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
-                delete_unused_value_func(node, body, chunk_var)
+                delete_unused_value_func(node, body, chunk_inputs_names)
 
         if node_idx in chunk_ends:
-            body.append(_gen_loop_end(node.name, chunk_var))
-            chunk_var = []
+            body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
            within_chunk_region = False
+            region_idx += 1
 
         node_idx += 1
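
Putting the emit loop together, a hypothetical traced forward with a single chunk region (2, 3) would be rewritten roughly as below; every name, shape, and op is illustrative rather than taken from the diff:

    def forward(self, x):
        linear_1 = self.linear(x);  x = None    # idx 1, outside the region
        # _gen_loop_start: allocate the full output, open the chunk loop
        chunk_result = torch.empty([1, 16, 64], dtype=linear_1.dtype, device=linear_1.device); chunk_size = 2
        for gen_chunk_idx in range(0, linear_1.shape[1], chunk_size):
            chunk_tensor = linear_1[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]
            relu_1 = torch.relu(chunk_tensor)    # idx 2: "(linear_1)" rewritten to "(chunk_tensor)"
            mul_1 = relu_1 * 2;  relu_1 = None   # idx 3: region output
            # _gen_loop_end: write this chunk into its slice of the result
            chunk_result[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :] = mul_1
        mul_1 = chunk_result; chunk_result = None; chunk_size = None; linear_1 = None
        return mul_1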
@@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:
             # if any node has a list of labels for activation_checkpoint, we
             # will use nested type of activation checkpoint codegen
-            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values)
+            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
 
         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body

View File

@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
     # setattr(node, "activation_offload", [0, True, False])
 
     codegen = ChunkCodeGen(gm_prop)
-    # graph.set_codegen(codegen)
+    graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()
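
The test change simply uncomments graph.set_codegen(codegen), so recompilation actually goes through the chunk codegen. A quick way to verify the result (gm.code is the standard torch.fx attribute holding the generated source) would be:

    gm.recompile()
    print(gm.code)    # the forward body should now contain the gen_chunk_idx loop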