mirror of https://github.com/hpcaitech/ColossalAI
[fx] Modify offload codegen (#1618)
* [fx] modify offload codegen * [fx] remove repeated hook definitions * [fx] modify offload testpull/1617/head^2
parent
9eae855408
commit
d6b01feb66
|
@ -22,8 +22,14 @@ def _gen_saved_tensors_hooks():
|
|||
Generate saved tensors hooks
|
||||
"""
|
||||
|
||||
pack_hook = """def pack_hook(self, x):
|
||||
if getattr(x, "offload", None):
|
||||
pack_hook = """def pack_hook_input(self, x):
|
||||
if getattr(x, "offload", False):
|
||||
return (x.device, x.cpu())
|
||||
else:
|
||||
return x
|
||||
|
||||
def pack_hook_no_input(self, x):
|
||||
if getattr(x, "offload", True):
|
||||
return (x.device, x.cpu())
|
||||
else:
|
||||
return x
|
||||
|
@ -40,12 +46,30 @@ def _gen_saved_tensors_hooks():
|
|||
return pack_hook, unpack_hook
|
||||
|
||||
|
||||
def _gen_save_tensors_hooks_context():
|
||||
"""
|
||||
Generate save tensors hooks context
|
||||
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
|
||||
"""Generate customized saved_tensors_hooks
|
||||
|
||||
Args:
|
||||
offload_input (bool, optional): whether we need offload input, if offload_input=False,
|
||||
we will use self.pack_hook_no_input instead. Defaults to True.
|
||||
|
||||
Returns:
|
||||
str: generated context
|
||||
"""
|
||||
|
||||
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
|
||||
if offload_input:
|
||||
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
|
||||
else:
|
||||
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
|
||||
return context
|
||||
|
||||
|
||||
def _gen_save_on_cpu_context():
|
||||
"""
|
||||
Generate save on cpu context
|
||||
"""
|
||||
|
||||
context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
|
||||
return context
|
||||
|
||||
|
||||
|
@ -118,6 +142,51 @@ def _find_ckpt_regions(nodes: List[Node]):
|
|||
return ckpt_regions
|
||||
|
||||
|
||||
def _find_offload_regions(nodes: List[Node]):
|
||||
"""This function is to find the offload regions
|
||||
In pofo algorithm, during annotation, we will annotate the offload region with the
|
||||
tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload
|
||||
region's index, offload_input is a bool type indicates whether we need to offload
|
||||
the input, offload_bar is a bool type indicates whether we need to offload all the
|
||||
intermediate x_bars of this region.
|
||||
"""
|
||||
offload_regions = []
|
||||
offload_labels = []
|
||||
start = -1
|
||||
end = -1
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(nodes):
|
||||
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
|
||||
act_offload_label = node.activation_offload
|
||||
|
||||
if current_region == None:
|
||||
current_region = act_offload_label
|
||||
start = idx
|
||||
offload_labels.append(act_offload_label)
|
||||
|
||||
if act_offload_label != current_region:
|
||||
assert start != -1
|
||||
offload_regions.append((start, idx - 1))
|
||||
offload_labels.append(act_offload_label)
|
||||
current_region = act_offload_label
|
||||
start = idx
|
||||
end = -1
|
||||
|
||||
else:
|
||||
if current_region is not None:
|
||||
end = idx - 1
|
||||
assert start != -1 and end != -1
|
||||
offload_regions.append((start, end))
|
||||
start = end = -1
|
||||
current_region = None
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
return offload_regions, offload_labels
|
||||
|
||||
|
||||
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
||||
"""
|
||||
Generate the checkpoint function definition
|
||||
|
@ -322,8 +391,23 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
|
|||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
|
||||
# find the offload regions
|
||||
offload_regions, offload_labels = _find_offload_regions(nodes)
|
||||
offload_starts = [item[0] for item in offload_regions]
|
||||
offload_ends = [item[1] for item in offload_regions]
|
||||
offload_inputs = []
|
||||
offload_outputs = []
|
||||
within_offload_region = False
|
||||
|
||||
node_list = list(nodes)
|
||||
|
||||
# find the input and output var names for each offload region
|
||||
for idx, (start, end) in enumerate(offload_regions):
|
||||
offload_node_list = node_list[start:end + 1]
|
||||
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
||||
offload_inputs.append(inputs)
|
||||
offload_outputs.append(outputs)
|
||||
|
||||
# this flag is to prevent repeated insert of save tensors
|
||||
# hooks definition in ckpt_func
|
||||
is_hook_inserted = False
|
||||
|
@ -343,19 +427,31 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
|
|||
else:
|
||||
node = node_list[node_idx]
|
||||
|
||||
# if a node is outside of checkpoint region and want to offload
|
||||
# it's input activation, we will use torch.saved_tensors_hooks
|
||||
# to complete the offload process.
|
||||
if getattr(node, "activation_offload", False):
|
||||
if node_idx in offload_starts:
|
||||
offload_label = offload_labels[offload_starts.index(node_idx)]
|
||||
_, offload_input, offload_bar = offload_label
|
||||
within_offload_region = True
|
||||
|
||||
# insert hook functions if needed
|
||||
if not is_hook_inserted:
|
||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
||||
is_hook_inserted = True
|
||||
|
||||
for par in node.all_input_nodes:
|
||||
# annotate the input tensor for pack hook
|
||||
body.append(f"setattr({repr(par)}, 'offload', True)\n")
|
||||
if offload_input and offload_bar:
|
||||
body.append(_gen_save_on_cpu_context())
|
||||
|
||||
body.append(_gen_save_tensors_hooks_context())
|
||||
elif offload_input:
|
||||
for par in offload_inputs[offload_label[0]]:
|
||||
body.append(f"setattr({par}, 'offload', True)\n")
|
||||
body.append(_gen_save_tensors_hooks_context(offload_input=True))
|
||||
|
||||
else:
|
||||
for par in offload_inputs[offload_label[0]]:
|
||||
body.append(f"setattr({par}, 'offload', False)\n")
|
||||
body.append(_gen_save_tensors_hooks_context(offload_input=False))
|
||||
|
||||
if within_offload_region:
|
||||
emit_node_func(node, body)
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body)
|
||||
|
@ -363,6 +459,10 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
|
|||
else:
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
|
||||
if node_idx in offload_ends:
|
||||
within_offload_region = False
|
||||
|
||||
node_idx += 1
|
||||
|
||||
|
||||
|
@ -375,6 +475,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
|||
output_vars = []
|
||||
within_ckpt_region = False
|
||||
|
||||
# find the offload regions
|
||||
offload_regions, offload_labels = _find_offload_regions(nodes)
|
||||
offload_starts = [item[0] for item in offload_regions]
|
||||
offload_ends = [item[1] for item in offload_regions]
|
||||
offload_inputs = []
|
||||
offload_outputs = []
|
||||
within_offload_region = False
|
||||
|
||||
node_list = list(nodes)
|
||||
|
||||
# use this variable to avoid inserting hook functions
|
||||
|
@ -388,6 +496,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
|||
input_vars.append(inputs)
|
||||
output_vars.append(outputs)
|
||||
|
||||
# find the input and output var names for each offload region
|
||||
for idx, (start, end) in enumerate(offload_regions):
|
||||
offload_node_list = node_list[start:end + 1]
|
||||
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
||||
offload_inputs.append(inputs)
|
||||
offload_outputs.append(outputs)
|
||||
|
||||
# append code text to body
|
||||
for idx, node in enumerate(node_list):
|
||||
# if this is the first node of the ckpt region
|
||||
|
@ -398,6 +513,30 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
|||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
within_ckpt_region = True
|
||||
|
||||
if idx in offload_starts:
|
||||
offload_label = offload_labels[offload_starts.index(idx)]
|
||||
_, offload_input, offload_bar = offload_label
|
||||
within_offload_region = True
|
||||
|
||||
# insert hook functions if needed
|
||||
if not is_hook_inserted:
|
||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
||||
is_hook_inserted = True
|
||||
|
||||
if offload_input and offload_bar:
|
||||
body.append(_gen_save_on_cpu_context())
|
||||
|
||||
elif offload_input:
|
||||
for par in offload_inputs[offload_label[0]]:
|
||||
body.append(f"setattr({par}, 'offload', True)\n")
|
||||
body.append(_gen_save_tensors_hooks_context(offload_input=True))
|
||||
|
||||
else:
|
||||
for par in offload_inputs[offload_label[0]]:
|
||||
body.append(f"setattr({par}, 'offload', False)\n")
|
||||
body.append(_gen_save_tensors_hooks_context(offload_input=False))
|
||||
|
||||
# NOTE: emit_node does not emit a string with newline. It depends
|
||||
# on delete_unused_values to append one
|
||||
# NOTE: currently we separate body and ckpt_func definition
|
||||
|
@ -405,27 +544,15 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
|||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
|
||||
elif within_offload_region:
|
||||
emit_node_func(node, body)
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body)
|
||||
|
||||
else:
|
||||
# if a node is outside of checkpoint region wants to offload
|
||||
# it's input activation, we will use torch.saved_tensors_hooks
|
||||
# to complete the offload process.
|
||||
if getattr(node, "activation_offload", False):
|
||||
if not is_hook_inserted:
|
||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
||||
|
||||
for par in node.all_input_nodes:
|
||||
# annotate the input tensor for pack hook
|
||||
body.append(f"setattr({repr(par)}, 'offload', True)\n")
|
||||
|
||||
body.append(_gen_save_tensors_hooks_context())
|
||||
emit_node_func(node, body)
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body)
|
||||
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
|
||||
if idx in end_idx:
|
||||
# if this is the last node of the ckpt region
|
||||
|
@ -470,6 +597,9 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
|||
body.append(usage)
|
||||
within_ckpt_region = False
|
||||
|
||||
if idx in offload_ends:
|
||||
within_offload_region = False
|
||||
|
||||
|
||||
if CODEGEN_AVAILABLE:
|
||||
|
||||
|
|
|
@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):
|
|||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear0 = torch.nn.Linear(4, 4)
|
||||
self.linear1 = torch.nn.Linear(4, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
self.linear3 = torch.nn.Linear(4, 4)
|
||||
self.linear4 = torch.nn.Linear(4, 4)
|
||||
self.linear5 = torch.nn.Linear(4, 4)
|
||||
self.linear6 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear0(x)
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
x = self.linear3(x)
|
||||
x = self.linear4(x)
|
||||
x = self.linear5(x)
|
||||
x = self.linear6(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -78,25 +82,32 @@ def _run_offload_codegen(rank):
|
|||
# also annotate the activation_checkpoint so we could test both types
|
||||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", True)
|
||||
if node.name == "linear3":
|
||||
setattr(node, "activation_offload", True)
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", (1, True, True))
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", (2, False, True))
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
|
||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||
gm.recompile()
|
||||
print(gm)
|
||||
|
||||
# assert we have all the components
|
||||
code = graph.python_code("self").src
|
||||
assert "def pack_hook(self, x):" in code and \
|
||||
assert "def pack_hook_input(self, x):" in code and \
|
||||
"def unpack_hook(self, packed):" in code and \
|
||||
"setattr(linear1, 'offload', True)" in code and \
|
||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
|
||||
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
|
||||
"def pack_hook_no_input(self, x):" in code and \
|
||||
"setattr(x, 'offload', True)" in code and \
|
||||
"setattr(linear3, 'offload', False)" in code and \
|
||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
|
||||
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
|
||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
|
||||
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
|
||||
|
||||
_test_fwd_and_bwd(model, gm, data)
|
||||
gpc.destroy()
|
||||
|
@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):
|
|||
# also annotate the activation_checkpoint so we could test both types
|
||||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", True)
|
||||
if node.name == "linear3":
|
||||
setattr(node, "activation_offload", True)
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", (1, True, True))
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", (2, False, True))
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
|
||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||
gm.recompile()
|
||||
print(gm)
|
||||
|
||||
# assert we have all the components
|
||||
code = graph.python_code("self").src
|
||||
assert "def pack_hook(self, x):" in code and \
|
||||
assert "def pack_hook_input(self, x):" in code and \
|
||||
"def unpack_hook(self, packed):" in code and \
|
||||
"setattr(linear1, 'offload', True)" in code and \
|
||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
|
||||
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
|
||||
"def pack_hook_no_input(self, x):" in code and \
|
||||
"setattr(x, 'offload', True)" in code and \
|
||||
"setattr(linear3, 'offload', False)" in code and \
|
||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
|
||||
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
|
||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
|
||||
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
|
||||
|
||||
_test_fwd_and_bwd(model, gm, data)
|
||||
gpc.destroy()
|
||||
|
|
Loading…
Reference in New Issue