basic chunk

pull/2364/head
oahzxl 2022-11-02 13:59:48 +08:00
parent 87cddf7e14
commit 78cfe4362b
2 changed files with 41 additions and 40 deletions

View File

@ -46,6 +46,19 @@ def pack_hook_no_input(self, x):
return pack_hook, unpack_hook
def _gen_loop_5(to_keep):
context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
return context
def _gen_loop_5_final(final_name, to_keep):
context = " chunk_result.append(" + final_name + ")\n"
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
context += final_name + " = chunk_result; chunk_result = None\n"
return context
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
"""Generate customized saved_tensors_hooks
@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
is_hook_inserted = False
node_idx = 0
while 1:
to_keep = []
while node_idx < len(node_list):
# break if we finish the processing all the nodes
if node_idx >= len(node_list):
break
# process ckpt_regions
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
# process node in forward function
else:
node = node_list[node_idx]
if node_idx in chunk_starts:
chunk_label = chunk_labels[chunk_starts.index(node_idx)]
_, chunk_input, chunk_bar = chunk_label
# save chunk input var, dont delete it
to_keep.extend(node.args[0].name)
within_chunk_region = True
# insert hook functions if needed
if not is_hook_inserted:
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
is_hook_inserted = True
if chunk_input and chunk_bar:
body.append(_gen_save_on_cpu_context())
elif chunk_input:
for par in chunk_inputs[chunk_label[0]]:
body.append(f"setattr({par}, 'offload', True)\n")
body.append(_gen_save_tensors_hooks_context(offload_input=True))
else:
for par in chunk_inputs[chunk_label[0]]:
body.append(f"setattr({par}, 'offload', False)\n")
body.append(_gen_save_tensors_hooks_context(offload_input=False))
# add for loop
body.append(_gen_loop_5(to_keep[0]))
# change first node's input to new chunked var
node_args = list(node.args)
node_args[0] = 'chunk_tensor'
if within_chunk_region:
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body)
delete_unused_value_func(node, body, to_keep)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, to_keep)
if node_idx in chunk_ends:
body.append(_gen_loop_5_final(node.name, to_keep))
to_keep = []
within_chunk_region = False
node_idx += 1
@ -572,7 +568,7 @@ if CODEGEN_AVAILABLE:
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body):
def delete_unused_values(user: Node, body, to_keep=[]):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
@ -584,6 +580,9 @@ if CODEGEN_AVAILABLE:
body.append('\n')
return
nodes_to_delete = user_to_last_uses.get(user, [])
for n in nodes_to_delete:
if n.name in to_keep:
nodes_to_delete.remove(n)
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
@ -694,4 +693,5 @@ if CODEGEN_AVAILABLE:
{prologue}
{code}"""
print(fn_code)
return PythonCode(fn_code, globals_)

View File

@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
# test forward
non_fx_out = model(data)
fx_out = gm(data)
print(non_fx_out.shape, fx_out.shape)
assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
# test barckward
@ -86,13 +87,13 @@ def _run_offload_codegen(rank):
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear2":
setattr(node, "activation_offload", [1, True, True])
if node.name == "linear4":
setattr(node, "activation_offload", [2, False, True])
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
# if node.name == "linear2":
# setattr(node, "activation_offload", [1, True, True])
# if node.name == "linear4":
# setattr(node, "activation_offload", [2, False, True])
# if node.name == "linear5":
# setattr(node, "activation_checkpoint", [0])
# setattr(node, "activation_offload", True)
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()