basic chunk

pull/2364/head
oahzxl 2022-11-02 13:59:48 +08:00
parent 87cddf7e14
commit 78cfe4362b
2 changed files with 41 additions and 40 deletions

View File

@@ -46,6 +46,19 @@ def pack_hook_no_input(self, x):
     return pack_hook, unpack_hook
 
 
+def _gen_loop_5(to_keep):
+    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
+    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
+    return context
+
+
+def _gen_loop_5_final(final_name, to_keep):
+    context = "    chunk_result.append(" + final_name + ")\n"
+    context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
+    context += final_name + " = chunk_result; chunk_result = None\n"
+    return context
+
+
 def _gen_save_tensors_hooks_context(offload_input=True) -> str:
     """Generate customized saved_tensors_hooks
@@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
-    is_hook_inserted = False
     node_idx = 0
-    while 1:
+    to_keep = []
+    while node_idx < len(node_list):
         # break if we finish the processing all the nodes
         if node_idx >= len(node_list):
             break
 
         # process ckpt_regions
         if node_idx in start_idx:
             ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
             emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
             node_idx += len(ckpt_node_list)
 
         # process node in forward function
         else:
             node = node_list[node_idx]
 
             if node_idx in chunk_starts:
-                chunk_label = chunk_labels[chunk_starts.index(node_idx)]
-                _, chunk_input, chunk_bar = chunk_label
+                # save chunk input var, dont delete it
+                to_keep.extend(node.args[0].name)
                 within_chunk_region = True
 
-                # insert hook functions if needed
-                if not is_hook_inserted:
-                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
-                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
-                    is_hook_inserted = True
-
-                if chunk_input and chunk_bar:
-                    body.append(_gen_save_on_cpu_context())
-                elif chunk_input:
-                    for par in chunk_inputs[chunk_label[0]]:
-                        body.append(f"setattr({par}, 'offload', True)\n")
-                    body.append(_gen_save_tensors_hooks_context(offload_input=True))
-                else:
-                    for par in chunk_inputs[chunk_label[0]]:
-                        body.append(f"setattr({par}, 'offload', False)\n")
-                    body.append(_gen_save_tensors_hooks_context(offload_input=False))
+                # add for loop
+                body.append(_gen_loop_5(to_keep[0]))
+                # change first node's input to new chunked var
+                node_args = list(node.args)
+                node_args[0] = 'chunk_tensor'
 
             if within_chunk_region:
                 emit_node_func(node, body)
                 body[-1] = '    ' + body[-1]
-                delete_unused_value_func(node, body)
+                delete_unused_value_func(node, body, to_keep)
             else:
                 emit_node_func(node, body)
-                delete_unused_value_func(node, body)
+                if node_idx not in chunk_inputs:
+                    delete_unused_value_func(node, body, to_keep)
 
             if node_idx in chunk_ends:
+                body.append(_gen_loop_5_final(node.name, to_keep))
+                to_keep = []
                 within_chunk_region = False
 
             node_idx += 1
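
The emitted loop is a generated version of hand-chunking a forward pass along the leading dimension. A minimal standalone sketch of the same computation (chunked_forward and f are hypothetical names; the fixed range(4) mirrors _gen_loop_5 and assumes dim 0 has exactly 4 slices):

    import torch

    def chunked_forward(f, x):
        # slice dim 0, run f on each slice, then stitch the results back together
        chunk_result = []
        for gen_loop_idx in range(4):
            chunk_tensor = x[gen_loop_idx, :]
            chunk_result.append(f(chunk_tensor))
        return torch.cat(chunk_result, dim=0)

Note that integer indexing drops the leading dim: for x of shape (4, 8), each x[gen_loop_idx, :] has shape (8,), so the cat yields (32,) rather than (4, 8); the shape print added to the test below makes exactly this kind of mismatch visible.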
@@ -572,7 +568,7 @@ if CODEGEN_AVAILABLE:
             map_arg(node.kwargs, lambda n: register_last_uses(n, node))
 
         # NOTE: we add a variable to distinguish body and ckpt_func
-        def delete_unused_values(user: Node, body):
+        def delete_unused_values(user: Node, body, to_keep=[]):
            """
            Delete values after their last use. This ensures that values that are
            not used in the remainder of the code are freed and the memory usage
@@ -584,6 +580,9 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
+            for n in nodes_to_delete:
+                if n.name in to_keep:
+                    nodes_to_delete.remove(n)
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {to_delete_str}\n')
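
One caveat in the filter just added: calling list.remove while iterating the same list skips the element that follows each removal, so two consecutive kept nodes would leave the second one in nodes_to_delete (and the mutable default to_keep=[] is shared across calls, though it is only read here). An equivalent filter without in-place mutation:

    # rebuild the deletion list, dropping nodes whose names must be kept
    nodes_to_delete = [n for n in nodes_to_delete if n.name not in to_keep]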
@@ -694,4 +693,5 @@ if CODEGEN_AVAILABLE:
 {prologue}
 {code}"""
 
+        print(fn_code)
         return PythonCode(fn_code, globals_)

View File

@@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
     # test forward
     non_fx_out = model(data)
     fx_out = gm(data)
+    print(non_fx_out.shape, fx_out.shape)
     assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
 
     # test backward
@@ -86,13 +87,13 @@ def _run_offload_codegen(rank):
             setattr(node, "activation_offload", [0, True, False])
         if node.name == "linear1":
             setattr(node, "activation_offload", [0, True, False])
-        if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
-        if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
-        if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+        # if node.name == "linear2":
+        #     setattr(node, "activation_offload", [1, True, True])
+        # if node.name == "linear4":
+        #     setattr(node, "activation_offload", [2, False, True])
+        # if node.name == "linear5":
+        #     setattr(node, "activation_checkpoint", [0])
+        #     setattr(node, "activation_offload", True)
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()