mirror of https://github.com/hpcaitech/ColossalAI
basic chunk
parent
87cddf7e14
commit
78cfe4362b
|
@ -46,6 +46,19 @@ def pack_hook_no_input(self, x):
|
||||||
return pack_hook, unpack_hook
|
return pack_hook, unpack_hook
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_loop_5(to_keep):
|
||||||
|
context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
|
||||||
|
context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_loop_5_final(final_name, to_keep):
|
||||||
|
context = " chunk_result.append(" + final_name + ")\n"
|
||||||
|
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
|
||||||
|
context += final_name + " = chunk_result; chunk_result = None\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
|
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
|
||||||
"""Generate customized saved_tensors_hooks
|
"""Generate customized saved_tensors_hooks
|
||||||
|
|
||||||
|
@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
||||||
|
|
||||||
# this flag is to prevent repeated insert of save tensors
|
# this flag is to prevent repeated insert of save tensors
|
||||||
# hooks definition in ckpt_func
|
# hooks definition in ckpt_func
|
||||||
is_hook_inserted = False
|
|
||||||
node_idx = 0
|
node_idx = 0
|
||||||
while 1:
|
to_keep = []
|
||||||
|
while node_idx < len(node_list):
|
||||||
# break if we finish the processing all the nodes
|
# break if we finish the processing all the nodes
|
||||||
if node_idx >= len(node_list):
|
if node_idx >= len(node_list):
|
||||||
break
|
break
|
||||||
|
|
||||||
# process ckpt_regions
|
|
||||||
if node_idx in start_idx:
|
|
||||||
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
|
|
||||||
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
|
|
||||||
node_idx += len(ckpt_node_list)
|
|
||||||
|
|
||||||
# process node in forward function
|
# process node in forward function
|
||||||
else:
|
else:
|
||||||
node = node_list[node_idx]
|
node = node_list[node_idx]
|
||||||
|
|
||||||
if node_idx in chunk_starts:
|
if node_idx in chunk_starts:
|
||||||
chunk_label = chunk_labels[chunk_starts.index(node_idx)]
|
# save chunk input var, dont delete it
|
||||||
_, chunk_input, chunk_bar = chunk_label
|
to_keep.extend(node.args[0].name)
|
||||||
within_chunk_region = True
|
within_chunk_region = True
|
||||||
|
# add for loop
|
||||||
# insert hook functions if needed
|
body.append(_gen_loop_5(to_keep[0]))
|
||||||
if not is_hook_inserted:
|
# change first node's input to new chunked var
|
||||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
node_args = list(node.args)
|
||||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
node_args[0] = 'chunk_tensor'
|
||||||
is_hook_inserted = True
|
|
||||||
|
|
||||||
if chunk_input and chunk_bar:
|
|
||||||
body.append(_gen_save_on_cpu_context())
|
|
||||||
|
|
||||||
elif chunk_input:
|
|
||||||
for par in chunk_inputs[chunk_label[0]]:
|
|
||||||
body.append(f"setattr({par}, 'offload', True)\n")
|
|
||||||
body.append(_gen_save_tensors_hooks_context(offload_input=True))
|
|
||||||
|
|
||||||
else:
|
|
||||||
for par in chunk_inputs[chunk_label[0]]:
|
|
||||||
body.append(f"setattr({par}, 'offload', False)\n")
|
|
||||||
body.append(_gen_save_tensors_hooks_context(offload_input=False))
|
|
||||||
|
|
||||||
if within_chunk_region:
|
if within_chunk_region:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
body[-1] = ' ' + body[-1]
|
body[-1] = ' ' + body[-1]
|
||||||
delete_unused_value_func(node, body)
|
delete_unused_value_func(node, body, to_keep)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
delete_unused_value_func(node, body)
|
if node_idx not in chunk_inputs:
|
||||||
|
delete_unused_value_func(node, body, to_keep)
|
||||||
|
|
||||||
if node_idx in chunk_ends:
|
if node_idx in chunk_ends:
|
||||||
|
body.append(_gen_loop_5_final(node.name, to_keep))
|
||||||
|
to_keep = []
|
||||||
within_chunk_region = False
|
within_chunk_region = False
|
||||||
|
|
||||||
node_idx += 1
|
node_idx += 1
|
||||||
|
@ -572,7 +568,7 @@ if CODEGEN_AVAILABLE:
|
||||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||||
|
|
||||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||||
def delete_unused_values(user: Node, body):
|
def delete_unused_values(user: Node, body, to_keep=[]):
|
||||||
"""
|
"""
|
||||||
Delete values after their last use. This ensures that values that are
|
Delete values after their last use. This ensures that values that are
|
||||||
not used in the remainder of the code are freed and the memory usage
|
not used in the remainder of the code are freed and the memory usage
|
||||||
|
@ -584,6 +580,9 @@ if CODEGEN_AVAILABLE:
|
||||||
body.append('\n')
|
body.append('\n')
|
||||||
return
|
return
|
||||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||||
|
for n in nodes_to_delete:
|
||||||
|
if n.name in to_keep:
|
||||||
|
nodes_to_delete.remove(n)
|
||||||
if len(nodes_to_delete):
|
if len(nodes_to_delete):
|
||||||
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
|
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
|
||||||
body.append(f'; {to_delete_str}\n')
|
body.append(f'; {to_delete_str}\n')
|
||||||
|
@ -694,4 +693,5 @@ if CODEGEN_AVAILABLE:
|
||||||
|
|
||||||
{prologue}
|
{prologue}
|
||||||
{code}"""
|
{code}"""
|
||||||
|
print(fn_code)
|
||||||
return PythonCode(fn_code, globals_)
|
return PythonCode(fn_code, globals_)
|
||||||
|
|
|
@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
|
||||||
# test forward
|
# test forward
|
||||||
non_fx_out = model(data)
|
non_fx_out = model(data)
|
||||||
fx_out = gm(data)
|
fx_out = gm(data)
|
||||||
|
print(non_fx_out.shape, fx_out.shape)
|
||||||
assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
|
assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
|
||||||
|
|
||||||
# test barckward
|
# test barckward
|
||||||
|
@ -86,13 +87,13 @@ def _run_offload_codegen(rank):
|
||||||
setattr(node, "activation_offload", [0, True, False])
|
setattr(node, "activation_offload", [0, True, False])
|
||||||
if node.name == "linear1":
|
if node.name == "linear1":
|
||||||
setattr(node, "activation_offload", [0, True, False])
|
setattr(node, "activation_offload", [0, True, False])
|
||||||
if node.name == "linear2":
|
# if node.name == "linear2":
|
||||||
setattr(node, "activation_offload", [1, True, True])
|
# setattr(node, "activation_offload", [1, True, True])
|
||||||
if node.name == "linear4":
|
# if node.name == "linear4":
|
||||||
setattr(node, "activation_offload", [2, False, True])
|
# setattr(node, "activation_offload", [2, False, True])
|
||||||
if node.name == "linear5":
|
# if node.name == "linear5":
|
||||||
setattr(node, "activation_checkpoint", [0])
|
# setattr(node, "activation_checkpoint", [0])
|
||||||
setattr(node, "activation_offload", True)
|
# setattr(node, "activation_offload", True)
|
||||||
|
|
||||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
Loading…
Reference in New Issue