mirror of https://github.com/hpcaitech/ColossalAI
basic chunk
parent
87cddf7e14
commit
78cfe4362b
|
@ -46,6 +46,19 @@ def pack_hook_no_input(self, x):
|
|||
return pack_hook, unpack_hook
|
||||
|
||||
|
||||
def _gen_loop_5(to_keep):
|
||||
context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
|
||||
context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
|
||||
return context
|
||||
|
||||
|
||||
def _gen_loop_5_final(final_name, to_keep):
|
||||
context = " chunk_result.append(" + final_name + ")\n"
|
||||
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
|
||||
context += final_name + " = chunk_result; chunk_result = None\n"
|
||||
return context
|
||||
|
||||
|
||||
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
|
||||
"""Generate customized saved_tensors_hooks
|
||||
|
||||
|
@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
|
||||
# this flag is to prevent repeated insert of save tensors
|
||||
# hooks definition in ckpt_func
|
||||
is_hook_inserted = False
|
||||
node_idx = 0
|
||||
while 1:
|
||||
to_keep = []
|
||||
while node_idx < len(node_list):
|
||||
# break if we finish the processing all the nodes
|
||||
if node_idx >= len(node_list):
|
||||
break
|
||||
|
||||
# process ckpt_regions
|
||||
if node_idx in start_idx:
|
||||
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
|
||||
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
|
||||
node_idx += len(ckpt_node_list)
|
||||
|
||||
# process node in forward function
|
||||
else:
|
||||
node = node_list[node_idx]
|
||||
|
||||
if node_idx in chunk_starts:
|
||||
chunk_label = chunk_labels[chunk_starts.index(node_idx)]
|
||||
_, chunk_input, chunk_bar = chunk_label
|
||||
# save chunk input var, dont delete it
|
||||
to_keep.extend(node.args[0].name)
|
||||
within_chunk_region = True
|
||||
|
||||
# insert hook functions if needed
|
||||
if not is_hook_inserted:
|
||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
||||
is_hook_inserted = True
|
||||
|
||||
if chunk_input and chunk_bar:
|
||||
body.append(_gen_save_on_cpu_context())
|
||||
|
||||
elif chunk_input:
|
||||
for par in chunk_inputs[chunk_label[0]]:
|
||||
body.append(f"setattr({par}, 'offload', True)\n")
|
||||
body.append(_gen_save_tensors_hooks_context(offload_input=True))
|
||||
|
||||
else:
|
||||
for par in chunk_inputs[chunk_label[0]]:
|
||||
body.append(f"setattr({par}, 'offload', False)\n")
|
||||
body.append(_gen_save_tensors_hooks_context(offload_input=False))
|
||||
# add for loop
|
||||
body.append(_gen_loop_5(to_keep[0]))
|
||||
# change first node's input to new chunked var
|
||||
node_args = list(node.args)
|
||||
node_args[0] = 'chunk_tensor'
|
||||
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body)
|
||||
delete_unused_value_func(node, body, to_keep)
|
||||
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
if node_idx not in chunk_inputs:
|
||||
delete_unused_value_func(node, body, to_keep)
|
||||
|
||||
if node_idx in chunk_ends:
|
||||
body.append(_gen_loop_5_final(node.name, to_keep))
|
||||
to_keep = []
|
||||
within_chunk_region = False
|
||||
|
||||
node_idx += 1
|
||||
|
@ -572,7 +568,7 @@ if CODEGEN_AVAILABLE:
|
|||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body):
|
||||
def delete_unused_values(user: Node, body, to_keep=[]):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
|
@ -584,6 +580,9 @@ if CODEGEN_AVAILABLE:
|
|||
body.append('\n')
|
||||
return
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
for n in nodes_to_delete:
|
||||
if n.name in to_keep:
|
||||
nodes_to_delete.remove(n)
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
|
||||
body.append(f'; {to_delete_str}\n')
|
||||
|
@ -694,4 +693,5 @@ if CODEGEN_AVAILABLE:
|
|||
|
||||
{prologue}
|
||||
{code}"""
|
||||
print(fn_code)
|
||||
return PythonCode(fn_code, globals_)
|
||||
|
|
|
@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
|
|||
# test forward
|
||||
non_fx_out = model(data)
|
||||
fx_out = gm(data)
|
||||
print(non_fx_out.shape, fx_out.shape)
|
||||
assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
|
||||
|
||||
# test barckward
|
||||
|
@ -86,13 +87,13 @@ def _run_offload_codegen(rank):
|
|||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", [1, True, True])
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", [2, False, True])
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
# if node.name == "linear2":
|
||||
# setattr(node, "activation_offload", [1, True, True])
|
||||
# if node.name == "linear4":
|
||||
# setattr(node, "activation_offload", [2, False, True])
|
||||
# if node.name == "linear5":
|
||||
# setattr(node, "activation_checkpoint", [0])
|
||||
# setattr(node, "activation_offload", True)
|
||||
|
||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||
gm.recompile()
|
||||
|
|
Loading…
Reference in New Issue