mirror of https://github.com/hpcaitech/ColossalAI
basic chunk
parent
f8aeecef46
commit
c35718e8db
138
chunk_codegen.py
138
chunk_codegen.py
|
@ -18,16 +18,61 @@ else:
|
||||||
__all__ = ['python_code_with_activation_checkpoint']
|
__all__ = ['python_code_with_activation_checkpoint']
|
||||||
|
|
||||||
|
|
||||||
def _gen_loop_start(to_keep, chunk_size=2):
|
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
||||||
context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
|
new_shape = "["
|
||||||
context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
|
for idx, i in enumerate(shape):
|
||||||
|
if idx == chunk_dim:
|
||||||
|
new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
|
||||||
|
else:
|
||||||
|
new_shape += ":"
|
||||||
|
new_shape += ", "
|
||||||
|
new_shape = new_shape[:-2] + "]"
|
||||||
|
return new_shape
|
||||||
|
|
||||||
|
|
||||||
|
def _get_first_non_single_dim(shape):
|
||||||
|
for idx, i in enumerate(shape):
|
||||||
|
if i == 1:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return idx
|
||||||
|
raise RuntimeError("can not get first non single dim for shape", shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
|
||||||
|
if len(chunk_input_meta) == 1:
|
||||||
|
node = chunk_input_meta[0]
|
||||||
|
node_shape = node.meta['tensor_meta'].shape
|
||||||
|
chunk_dim = _get_first_non_single_dim(node_shape)
|
||||||
|
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
|
||||||
|
out_shape = str(list(chunk_output.meta['tensor_meta'].shape))
|
||||||
|
|
||||||
|
context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
|
||||||
|
out_shape, node.name, node.name, chunk_size)
|
||||||
|
context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
|
||||||
|
context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _gen_loop_end(final_name, to_keep):
|
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
|
||||||
context = " chunk_result.append(" + final_name + ")\n"
|
chunk_inputs_name = chunk_inputs[0].name
|
||||||
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
|
chunk_outputs_name = chunk_outputs.name
|
||||||
context += final_name + " = chunk_result; chunk_result = None\n"
|
chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
|
||||||
|
chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape
|
||||||
|
chunk_dim = _get_first_non_single_dim(chunk_output_shape)
|
||||||
|
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
|
||||||
|
context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
|
||||||
|
|
||||||
|
context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
|
||||||
|
|
||||||
|
# determine if its the last use for chunk input
|
||||||
|
users_name = list(chunk_inputs[0].users.keys())
|
||||||
|
if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users_name]):
|
||||||
|
context += "; %s = None" % chunk_inputs_name
|
||||||
|
|
||||||
|
context += "\n"
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
||||||
for input_node in node._input_nodes.keys():
|
for input_node in node._input_nodes.keys():
|
||||||
node_repr = repr(input_node)
|
node_repr = repr(input_node)
|
||||||
if input_node not in nodes and node_repr not in input_nodes:
|
if input_node not in nodes and node_repr not in input_nodes:
|
||||||
input_nodes.append(node_repr)
|
input_nodes.append(input_node)
|
||||||
|
|
||||||
# if a node has a user node which is not in the node list
|
# if a node has a user node which is not in the node list
|
||||||
# we treat that user node as the node receiving the current node output
|
# we treat that user node as the node receiving the current node output
|
||||||
|
@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
||||||
for output_node in node.users.keys():
|
for output_node in node.users.keys():
|
||||||
node_repr = repr(node)
|
node_repr = repr(node)
|
||||||
if output_node not in nodes and node_repr not in output_nodes:
|
if output_node not in nodes and node_repr not in output_nodes:
|
||||||
output_nodes.append(node_repr)
|
output_nodes.append(output_node)
|
||||||
|
|
||||||
return input_nodes, output_nodes
|
return input_nodes, output_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def _find_idx_by_name(name, nodes_list):
|
||||||
|
for idx, node in enumerate(nodes_list):
|
||||||
|
if node.name == name:
|
||||||
|
return idx
|
||||||
|
raise RuntimeError("name %s not found in node list" % name)
|
||||||
|
|
||||||
|
|
||||||
def _find_offload_regions(nodes: List[Node]):
|
def _find_offload_regions(nodes: List[Node]):
|
||||||
"""This function is to find the offload regions
|
"""This function is to find the offload regions
|
||||||
In pofo algorithm, during annotation, we will annotate the offload region with the
|
In pofo algorithm, during annotation, we will annotate the offload region with the
|
||||||
|
@ -290,7 +342,7 @@ def emit_ckpt_func(body,
|
||||||
body.append(usage)
|
body.append(usage)
|
||||||
|
|
||||||
|
|
||||||
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
|
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
|
||||||
"""Emit code with nested activation checkpoint
|
"""Emit code with nested activation checkpoint
|
||||||
When we detect some of the node.activation_checkpoint is a List, we will use
|
When we detect some of the node.activation_checkpoint is a List, we will use
|
||||||
this function to emit the activation checkpoint codes.
|
this function to emit the activation checkpoint codes.
|
||||||
|
@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# find the offload regions
|
# find the offload regions
|
||||||
chunk_regions = [(1, 4)]
|
chunk_regions = [(2, 5)]
|
||||||
chunk_starts = [item[0] for item in chunk_regions]
|
chunk_starts = [item[0] for item in chunk_regions]
|
||||||
chunk_ends = [item[1] for item in chunk_regions]
|
chunk_ends = [item[1] for item in chunk_regions]
|
||||||
chunk_inputs = []
|
chunk_inputs = []
|
||||||
|
@ -319,48 +371,46 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
||||||
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
||||||
chunk_inputs.append(inputs)
|
chunk_inputs.append(inputs)
|
||||||
chunk_outputs.append(outputs)
|
chunk_outputs.append(outputs)
|
||||||
|
chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
|
||||||
|
chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
|
||||||
|
chunk_inputs_names = []
|
||||||
|
for i in chunk_inputs:
|
||||||
|
for j in i:
|
||||||
|
chunk_inputs_names.append(j.name)
|
||||||
|
|
||||||
# this flag is to prevent repeated insert of save tensors
|
# this flag is to prevent repeated insert of save tensors
|
||||||
# hooks definition in ckpt_func
|
# hooks definition in ckpt_func
|
||||||
node_idx = 0
|
node_idx = 0
|
||||||
chunk_var = []
|
region_idx = 0
|
||||||
while node_idx < len(node_list):
|
while node_idx < len(node_list):
|
||||||
# break if we finish the processing all the nodes
|
node = node_list[node_idx]
|
||||||
if node_idx >= len(node_list):
|
|
||||||
break
|
|
||||||
|
|
||||||
# process node in forward function
|
if node_idx in chunk_starts:
|
||||||
else:
|
within_chunk_region = True
|
||||||
node = node_list[node_idx]
|
|
||||||
|
# add for loop
|
||||||
|
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
|
||||||
|
body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))
|
||||||
|
|
||||||
|
if within_chunk_region:
|
||||||
|
emit_node_func(node, body)
|
||||||
|
# replace input var with chunk var
|
||||||
if node_idx in chunk_starts:
|
if node_idx in chunk_starts:
|
||||||
within_chunk_region = True
|
body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
|
||||||
|
body[-1] = ' ' + body[-1]
|
||||||
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||||
|
|
||||||
# save chunk input var, dont delete it
|
else:
|
||||||
chunk_var.append(node.args[0].name)
|
emit_node_func(node, body)
|
||||||
|
if node_idx not in chunk_inputs:
|
||||||
# add for loop
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||||
body.append(_gen_loop_start(chunk_var[0]))
|
|
||||||
|
|
||||||
if within_chunk_region:
|
|
||||||
emit_node_func(node, body)
|
|
||||||
# replace input var with chunk var
|
|
||||||
if node_idx in chunk_starts:
|
|
||||||
body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
|
|
||||||
body[-1] = ' ' + body[-1]
|
|
||||||
delete_unused_value_func(node, body, chunk_var)
|
|
||||||
|
|
||||||
else:
|
if node_idx in chunk_ends:
|
||||||
emit_node_func(node, body)
|
body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
|
||||||
if node_idx not in chunk_inputs:
|
within_chunk_region = False
|
||||||
delete_unused_value_func(node, body, chunk_var)
|
region_idx += 1
|
||||||
|
|
||||||
if node_idx in chunk_ends:
|
node_idx += 1
|
||||||
body.append(_gen_loop_end(node.name, chunk_var))
|
|
||||||
chunk_var = []
|
|
||||||
within_chunk_region = False
|
|
||||||
|
|
||||||
node_idx += 1
|
|
||||||
|
|
||||||
|
|
||||||
if CODEGEN_AVAILABLE:
|
if CODEGEN_AVAILABLE:
|
||||||
|
@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:
|
||||||
|
|
||||||
# if any node has a list of labels for activation_checkpoint, we
|
# if any node has a list of labels for activation_checkpoint, we
|
||||||
# will use nested type of activation checkpoint codegen
|
# will use nested type of activation checkpoint codegen
|
||||||
emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values)
|
emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
|
||||||
|
|
||||||
if len(body) == 0:
|
if len(body) == 0:
|
||||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||||
|
|
|
@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
|
||||||
# setattr(node, "activation_offload", [0, True, False])
|
# setattr(node, "activation_offload", [0, True, False])
|
||||||
|
|
||||||
codegen = ChunkCodeGen(gm_prop)
|
codegen = ChunkCodeGen(gm_prop)
|
||||||
# graph.set_codegen(codegen)
|
graph.set_codegen(codegen)
|
||||||
gm = ColoGraphModule(model, graph)
|
gm = ColoGraphModule(model, graph)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue