[fx] Modify offload codegen (#1618)

* [fx] modify offload codegen

* [fx] remove repeated hook definitions

* [fx] modify offload test
Boyuan Yao 2022-09-23 11:04:52 +08:00 committed by GitHub
parent 9eae855408
commit d6b01feb66
2 changed files with 200 additions and 52 deletions

@@ -22,8 +22,14 @@ def _gen_saved_tensors_hooks():
Generate saved tensors hooks
"""
pack_hook = """def pack_hook(self, x):
if getattr(x, "offload", None):
pack_hook = """def pack_hook_input(self, x):
if getattr(x, "offload", False):
return (x.device, x.cpu())
else:
return x
def pack_hook_no_input(self, x):
if getattr(x, "offload", True):
return (x.device, x.cpu())
else:
return x
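The matching unpack_hook is generated outside the lines shown here. As a rough, standalone sketch of the mechanism the generated hooks rely on (illustration only, not part of this commit), torch.autograd.graph.saved_tensors_hooks lets a pack/unpack pair move tensors saved for backward to CPU and restore them on demand; the generated hooks additionally consult the per-tensor offload attribute, which this sketch skips:

# Minimal saved_tensors_hooks sketch (PyTorch >= 1.10), offloading unconditionally.
import torch

def pack_hook(x):
    # called when autograd saves a tensor for backward: keep a CPU copy
    return (x.device, x.cpu())

def unpack_hook(packed):
    # called when backward needs the tensor again: move it back to its device
    device, tensor = packed
    return tensor.to(device)

a = torch.randn(4, 4, requires_grad=True)
b = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = (a @ b).sum()
loss.backward()    # saved tensors are brought back from CPU here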
@@ -40,12 +46,30 @@ def _gen_saved_tensors_hooks():
return pack_hook, unpack_hook
-def _gen_save_tensors_hooks_context():
-    """
-    Generate save tensors hooks context
-    """
+def _gen_save_tensors_hooks_context(offload_input=True) -> str:
+    """Generate customized saved_tensors_hooks
+
+    Args:
+        offload_input (bool, optional): whether the input of the region needs to be offloaded;
+            if offload_input=False, self.pack_hook_no_input is used instead. Defaults to True.
+
+    Returns:
+        str: generated context
+    """

-    context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
+    if offload_input:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
+    else:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
return context
def _gen_save_on_cpu_context():
"""
Generate save on cpu context
"""
context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
return context
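The string built by _gen_save_on_cpu_context expands to torch.autograd.graph.save_on_cpu, the blanket variant that offloads every tensor saved for backward inside the block. A minimal usage sketch, again only for illustration and not part of the commit:

# Minimal save_on_cpu sketch (PyTorch >= 1.10); recent PyTorch only pins memory
# when CUDA is available, so this also runs on a CPU-only build.
import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    out = layer(x).sum()

out.backward()    # activations saved inside the block are copied back here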
@@ -118,6 +142,51 @@ def _find_ckpt_regions(nodes: List[Node]):
return ckpt_regions
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload
region's index, offload_input is a bool type indicates whether we need to offload
the input, offload_bar is a bool type indicates whether we need to offload all the
intermediate x_bars of this region.
"""
offload_regions = []
offload_labels = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
act_offload_label = node.activation_offload
if current_region == None:
current_region = act_offload_label
start = idx
offload_labels.append(act_offload_label)
if act_offload_label != current_region:
assert start != -1
offload_regions.append((start, idx - 1))
offload_labels.append(act_offload_label)
current_region = act_offload_label
start = idx
end = -1
else:
if current_region is not None:
end = idx - 1
assert start != -1 and end != -1
offload_regions.append((start, end))
start = end = -1
current_region = None
else:
pass
return offload_regions, offload_labels
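As a quick illustration of the grouping logic above (not part of the commit, and using plain stand-in objects instead of torch.fx nodes): consecutive nodes that carry the same (idx, offload_input, offload_bar) tuple form one region, and a node with a different label or no label closes it.

# Hypothetical stand-in for fx nodes, only to show how regions are grouped.
class FakeNode:
    def __init__(self, label=None):
        if label is not None:
            self.activation_offload = label

nodes = [
    FakeNode((0, True, False)),    # index 0: region 0 starts
    FakeNode((0, True, False)),    # index 1: same label, still region 0
    FakeNode(),                    # index 2: unannotated, closes region 0
    FakeNode((1, False, True)),    # index 3: region 1 starts
    FakeNode(),                    # index 4: e.g. the graph output, closes region 1
]

# _find_offload_regions(nodes) would then return:
#   offload_regions == [(0, 1), (3, 3)]
#   offload_labels  == [(0, True, False), (1, False, True)]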
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
"""
Generate the checkpoint function definition
@@ -322,8 +391,23 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
# find the offload regions
offload_regions, offload_labels = _find_offload_regions(nodes)
offload_starts = [item[0] for item in offload_regions]
offload_ends = [item[1] for item in offload_regions]
offload_inputs = []
offload_outputs = []
within_offload_region = False
node_list = list(nodes)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
# this flag is to prevent repeated insertion of the saved tensors
# hooks definition in ckpt_func
is_hook_inserted = False
@@ -343,19 +427,31 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
else:
node = node_list[node_idx]
-# if a node is outside of checkpoint region and want to offload
-# it's input activation, we will use torch.saved_tensors_hooks
-# to complete the offload process.
-if getattr(node, "activation_offload", False):
+if node_idx in offload_starts:
+    offload_label = offload_labels[offload_starts.index(node_idx)]
+    _, offload_input, offload_bar = offload_label
+    within_offload_region = True

    # insert hook functions if needed
    if not is_hook_inserted:
        pack_hook, unpack_hook = _gen_saved_tensors_hooks()
        ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
        is_hook_inserted = True

-    for par in node.all_input_nodes:
-        # annotate the input tensor for pack hook
-        body.append(f"setattr({repr(par)}, 'offload', True)\n")
-    body.append(_gen_save_tensors_hooks_context())
+    if offload_input and offload_bar:
+        body.append(_gen_save_on_cpu_context())
+    elif offload_input:
+        for par in offload_inputs[offload_label[0]]:
+            body.append(f"setattr({par}, 'offload', True)\n")
+        body.append(_gen_save_tensors_hooks_context(offload_input=True))
+    else:
+        for par in offload_inputs[offload_label[0]]:
+            body.append(f"setattr({par}, 'offload', False)\n")
+        body.append(_gen_save_tensors_hooks_context(offload_input=False))

+if within_offload_region:
    emit_node_func(node, body)
    body[-1] = '    ' + body[-1]
    delete_unused_value_func(node, body)
@@ -363,6 +459,10 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
if node_idx in offload_ends:
within_offload_region = False
node_idx += 1
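To make the three branches above concrete, here is a small self-contained mimic (illustration only; region_inputs and the printed variable name are made up) of the prologue text appended to body for each label shape, before the region's own statements are emitted and indented under the context manager:

# Mirrors the offload_input / offload_bar dispatch above and returns the lines of
# generated source that precede an offload region in the emitted forward.
def sketch_region_prologue(offload_input, offload_bar, region_inputs):
    lines = []
    if offload_input and offload_bar:
        lines.append("with torch.autograd.graph.save_on_cpu(pin_memory=True):\n")
    elif offload_input:
        for name in region_inputs:
            lines.append(f"setattr({name}, 'offload', True)\n")
        lines.append("with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n")
    else:
        for name in region_inputs:
            lines.append(f"setattr({name}, 'offload', False)\n")
        lines.append("with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n")
    return lines

print("".join(sketch_region_prologue(True, False, ["linear1"])))
# setattr(linear1, 'offload', True)
# with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):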
@@ -375,6 +475,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
output_vars = []
within_ckpt_region = False
# find the offload regions
offload_regions, offload_labels = _find_offload_regions(nodes)
offload_starts = [item[0] for item in offload_regions]
offload_ends = [item[1] for item in offload_regions]
offload_inputs = []
offload_outputs = []
within_offload_region = False
node_list = list(nodes)
# use this variable to avoid inserting hook functions
@@ -388,6 +496,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
input_vars.append(inputs)
output_vars.append(outputs)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
# append code text to body
for idx, node in enumerate(node_list):
# if this is the first node of the ckpt region
@@ -398,6 +513,30 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
ckpt_func.append(f'{ckpt_fn_def}\n')
within_ckpt_region = True
if idx in offload_starts:
offload_label = offload_labels[offload_starts.index(idx)]
_, offload_input, offload_bar = offload_label
within_offload_region = True
# insert hook functions if needed
if not is_hook_inserted:
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
is_hook_inserted = True
if offload_input and offload_bar:
body.append(_gen_save_on_cpu_context())
elif offload_input:
for par in offload_inputs[offload_label[0]]:
body.append(f"setattr({par}, 'offload', True)\n")
body.append(_gen_save_tensors_hooks_context(offload_input=True))
else:
for par in offload_inputs[offload_label[0]]:
body.append(f"setattr({par}, 'offload', False)\n")
body.append(_gen_save_tensors_hooks_context(offload_input=False))
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
# NOTE: currently we separate body and ckpt_func definition
@@ -405,27 +544,15 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
    emit_node_func(node, ckpt_func)
    ckpt_func[-1] = '    ' + ckpt_func[-1]
    delete_unused_value_func(node, ckpt_func)
+elif within_offload_region:
+    emit_node_func(node, body)
+    body[-1] = '    ' + body[-1]
+    delete_unused_value_func(node, body)
else:
-    # if a node is outside of checkpoint region wants to offload
-    # it's input activation, we will use torch.saved_tensors_hooks
-    # to complete the offload process.
-    if getattr(node, "activation_offload", False):
-        if not is_hook_inserted:
-            pack_hook, unpack_hook = _gen_saved_tensors_hooks()
-            ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
-        for par in node.all_input_nodes:
-            # annotate the input tensor for pack hook
-            body.append(f"setattr({repr(par)}, 'offload', True)\n")
-        body.append(_gen_save_tensors_hooks_context())
-        emit_node_func(node, body)
-        body[-1] = '    ' + body[-1]
-        delete_unused_value_func(node, body)
-    else:
-        emit_node_func(node, body)
-        delete_unused_value_func(node, body)
+    emit_node_func(node, body)
+    delete_unused_value_func(node, body)
if idx in end_idx:
# if this is the last node of the ckpt region
@@ -470,6 +597,9 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
body.append(usage)
within_ckpt_region = False
if idx in offload_ends:
within_offload_region = False
if CODEGEN_AVAILABLE:

@@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear0 = torch.nn.Linear(4, 4)
self.linear1 = torch.nn.Linear(4, 4)
self.linear2 = torch.nn.Linear(4, 4)
self.linear3 = torch.nn.Linear(4, 4)
self.linear4 = torch.nn.Linear(4, 4)
self.linear5 = torch.nn.Linear(4, 4)
self.linear6 = torch.nn.Linear(4, 4)
def forward(self, x):
x = self.linear0(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.linear4(x)
x = self.linear5(x)
x = self.linear6(x)
return x
@@ -78,25 +82,32 @@ def _run_offload_codegen(rank):
# also annotate the activation_checkpoint so we could test both types
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", (0, True, False))
if node.name == "linear1":
setattr(node, "activation_offload", (0, True, False))
if node.name == "linear2":
setattr(node, "activation_offload", True)
if node.name == "linear3":
setattr(node, "activation_offload", True)
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", (1, True, True))
if node.name == "linear4":
setattr(node, "activation_offload", (2, False, True))
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()
print(gm)
# assert we have all the components
code = graph.python_code("self").src
assert "def pack_hook(self, x):" in code and \
assert "def pack_hook_input(self, x):" in code and \
"def unpack_hook(self, packed):" in code and \
"setattr(linear1, 'offload', True)" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
"def pack_hook_no_input(self, x):" in code and \
"setattr(x, 'offload', True)" in code and \
"setattr(linear3, 'offload', False)" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
_test_fwd_and_bwd(model, gm, data)
gpc.destroy()
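Putting the annotations and the asserted snippets together, the recompiled forward of this test model should look roughly like the sketch below (an approximation, not the commit's output; the statements the emitter inserts to free intermediate activations are omitted, and it assumes linear2 is the save_on_cpu region):

# Approximate shape of the generated MyNet.forward after gm.recompile().
def forward(self, x):
    # region 0: linear0 and linear1, label (0, True, False): offload only the input x
    setattr(x, 'offload', True)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):
        linear0 = self.linear0(x)
        linear1 = self.linear1(linear0)
    # region 1: linear2, label (1, True, True): offload everything saved inside
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        linear2 = self.linear2(linear1)
    linear3 = self.linear3(linear2)
    # region 2: linear4, label (2, False, True): offload everything except its input
    setattr(linear3, 'offload', False)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):
        linear4 = self.linear4(linear3)
    # linear5 is activation-checkpointed; the True flag carries its offload annotation
    linear5 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)
    linear6 = self.linear6(linear5)
    return linear6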
@@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):
# also annotate the activation_checkpoint so we could test both types
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", (0, True, False))
if node.name == "linear1":
setattr(node, "activation_offload", (0, True, False))
if node.name == "linear2":
setattr(node, "activation_offload", True)
if node.name == "linear3":
setattr(node, "activation_offload", True)
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", (1, True, True))
if node.name == "linear4":
setattr(node, "activation_offload", (2, False, True))
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()
print(gm)
# assert we have all the components
code = graph.python_code("self").src
assert "def pack_hook(self, x):" in code and \
assert "def pack_hook_input(self, x):" in code and \
"def unpack_hook(self, packed):" in code and \
"setattr(linear1, 'offload', True)" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
"def pack_hook_no_input(self, x):" in code and \
"setattr(x, 'offload', True)" in code and \
"setattr(linear3, 'offload', False)" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
_test_fwd_and_bwd(model, gm, data)
gpc.destroy()