mirror of https://github.com/hpcaitech/ColossalAI
commit c35718e8db (parent f8aeecef46): basic chunk
chunk_codegen.py | 138
@@ -18,16 +18,61 @@ else:
__all__ = ['python_code_with_activation_checkpoint']


def _gen_loop_start(to_keep, chunk_size=2):
    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
    new_shape = "["
    for idx, i in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape


def _get_first_non_single_dim(shape):
    for idx, i in enumerate(shape):
        if i == 1:
            continue
        else:
            return idx
    raise RuntimeError("can not get first non single dim for shape", shape)


def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    if len(chunk_input_meta) == 1:
        node = chunk_input_meta[0]
        node_shape = node.meta['tensor_meta'].shape
        chunk_dim = _get_first_non_single_dim(node_shape)
        chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
        out_shape = str(list(chunk_output.meta['tensor_meta'].shape))

        context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
            out_shape, node.name, node.name, chunk_size)
        context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
        context += "    chunk_tensor = %s%s\n" % (node.name, chunk_slice)
    else:
        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))
    return context


def _gen_loop_end(final_name, to_keep):
    context = "    chunk_result.append(" + final_name + ")\n"
    context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
    context += final_name + " = chunk_result; chunk_result = None\n"
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    chunk_inputs_name = chunk_inputs[0].name
    chunk_outputs_name = chunk_outputs.name
    chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
    chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape
    chunk_dim = _get_first_non_single_dim(chunk_output_shape)
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
    context = "    chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)

    context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"

    # determine if its the last use for chunk input
    users_name = list(chunk_inputs[0].users.keys())
    if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users_name]):
        context += "; %s = None" % chunk_inputs_name

    context += "\n"
    return context

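For reference, the two string helpers above can be exercised on their own. The following is a minimal sketch, not part of the commit, that re-implements them verbatim and prints the slice expression produced for an assumed activation shape of [1, 512, 64]; the shape is an illustrative assumption.

def _get_first_non_single_dim(shape):
    # return the first dimension whose size is not 1; chunking a size-1 dim would be pointless
    for idx, i in enumerate(shape):
        if i == 1:
            continue
        else:
            return idx
    raise RuntimeError("can not get first non single dim for shape", shape)


def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
    # build a slicing string such as "[:, i:i + chunk_size, :]" that chunks exactly one dim
    new_shape = "["
    for idx, i in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape


shape = [1, 512, 64]                        # assumed shape of the chunked activation
dim = _get_first_non_single_dim(shape)      # -> 1, since dim 0 has size 1
print(_gen_chunk_slice_dim(dim, "gen_chunk_idx", shape))
# prints: [:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]

With that slice string, the new _gen_loop_start emits a prologue that pre-allocates chunk_result with torch.empty(...), loops gen_chunk_idx over the chunked dimension in steps of chunk_size, and assigns chunk_tensor from the sliced input; _gen_loop_end writes each per-chunk result back into chunk_result through the same slice.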
@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
        for input_node in node._input_nodes.keys():
            node_repr = repr(input_node)
            if input_node not in nodes and node_repr not in input_nodes:
                input_nodes.append(node_repr)
                input_nodes.append(input_node)

        # if a node has a user node which is not in the node list
        # we treat that user node as the node receiving the current node output

@@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
        for output_node in node.users.keys():
            node_repr = repr(node)
            if output_node not in nodes and node_repr not in output_nodes:
                output_nodes.append(node_repr)
                output_nodes.append(output_node)

    return input_nodes, output_nodes


def _find_idx_by_name(name, nodes_list):
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)


def _find_offload_regions(nodes: List[Node]):
    """This function is to find the offload regions
    In pofo algorithm, during annotation, we will annotate the offload region with the

@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
    body.append(usage)


def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
    """Emit code with nested activation checkpoint
    When we detect some of the node.activation_checkpoint is a List, we will use
    this function to emit the activation checkpoint codes.

@@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
    """

    # find the offload regions
    chunk_regions = [(1, 4)]
    chunk_regions = [(2, 5)]
    chunk_starts = [item[0] for item in chunk_regions]
    chunk_ends = [item[1] for item in chunk_regions]
    chunk_inputs = []

@@ -319,48 +371,46 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
        chunk_inputs.append(inputs)
        chunk_outputs.append(outputs)

    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
    chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
    chunk_inputs_names = []
    for i in chunk_inputs:
        for j in i:
            chunk_inputs_names.append(j.name)

    # this flag is to prevent repeated insert of save tensors
    # hooks definition in ckpt_func
    node_idx = 0
    chunk_var = []
    region_idx = 0
    while node_idx < len(node_list):
        # break if we finish the processing all the nodes
        if node_idx >= len(node_list):
            break
        node = node_list[node_idx]

        # process node in forward function
        else:
            node = node_list[node_idx]
        if node_idx in chunk_starts:
            within_chunk_region = True

            # add for loop
            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
            body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            if node_idx in chunk_starts:
                within_chunk_region = True
                body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
            body[-1] = '    ' + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)

            # save chunk input var, dont delete it
            chunk_var.append(node.args[0].name)

            # add for loop
            body.append(_gen_loop_start(chunk_var[0]))

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            if node_idx in chunk_starts:
                body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
            body[-1] = '    ' + body[-1]
            delete_unused_value_func(node, body, chunk_var)
        else:
            emit_node_func(node, body)
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        else:
            emit_node_func(node, body)
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_var)
        if node_idx in chunk_ends:
            body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
            within_chunk_region = False
            region_idx += 1

        if node_idx in chunk_ends:
            body.append(_gen_loop_end(node.name, chunk_var))
            chunk_var = []
            within_chunk_region = False

        node_idx += 1
        node_idx += 1

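Taken together, the emission loop above is meant to wrap the nodes between a chunk_start and its chunk_end in a per-chunk for loop inside the generated forward. Below is a hand-written sketch, not produced by this commit, of the kind of code it aims at; the node names, tensor shapes, and matmul ops are illustrative assumptions.

import torch


def forward_sketch(x, w1, w2):
    # prologue in the style of _gen_loop_start: pre-allocate the region output and
    # iterate over the first non-single dimension of the region input
    chunk_result = torch.empty(x.shape[0], w2.shape[1], dtype=x.dtype, device=x.device); chunk_size = 2
    for gen_chunk_idx in range(0, x.shape[0], chunk_size):
        chunk_tensor = x[gen_chunk_idx:gen_chunk_idx + chunk_size, :]
        # nodes inside the chunk region, with their region input renamed to chunk_tensor
        linear_1 = chunk_tensor @ w1
        linear_2 = linear_1 @ w2
        # epilogue in the style of _gen_loop_end: write the chunk into the output buffer
        chunk_result[gen_chunk_idx:gen_chunk_idx + chunk_size, :] = linear_2
    out = chunk_result; chunk_result = None; chunk_size = None; x = None
    return out


print(forward_sketch(torch.rand(4, 8), torch.rand(8, 16), torch.rand(16, 3)).shape)    # torch.Size([4, 3])

Only chunk_size rows of the intermediate activations (linear_1, linear_2) are alive at any time, which is the memory saving the chunked codegen is after.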
@@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:

        # if any node has a list of labels for activation_checkpoint, we
        # will use nested type of activation checkpoint codegen
        emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values)
        emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body

@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
    # setattr(node, "activation_offload", [0, True, False])

    codegen = ChunkCodeGen(gm_prop)
    # graph.set_codegen(codegen)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()
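After gm.recompile(), a natural sanity check (not part of this diff) is that the chunked forward still matches the eager module. A rough sketch, with the input shape assumed for illustration:

import torch

data = torch.rand(4, 4)         # assumed input shape
out_chunked = gm(data)          # runs the code produced by ChunkCodeGen
out_eager = model(data)         # original eager forward of the traced model
assert torch.allclose(out_chunked, out_eager, atol=1e-5)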