mirror of https://github.com/hpcaitech/ColossalAI
[fx] Modify offload codegen (#1618)
* [fx] modify offload codegen
* [fx] remove repeated hook definitions
* [fx] modify offload test
pull/1617/head^2
parent
9eae855408
commit
d6b01feb66
|
@ -22,8 +22,14 @@ def _gen_saved_tensors_hooks():
|
||||||
Generate saved tensors hooks
|
Generate saved tensors hooks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pack_hook = """def pack_hook(self, x):
|
pack_hook = """def pack_hook_input(self, x):
|
||||||
if getattr(x, "offload", None):
|
if getattr(x, "offload", False):
|
||||||
|
return (x.device, x.cpu())
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def pack_hook_no_input(self, x):
|
||||||
|
if getattr(x, "offload", True):
|
||||||
return (x.device, x.cpu())
|
return (x.device, x.cpu())
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
@ -40,12 +46,30 @@ def _gen_saved_tensors_hooks():
|
||||||
return pack_hook, unpack_hook
|
return pack_hook, unpack_hook
|
||||||
|
|
||||||
|
|
||||||
def _gen_save_tensors_hooks_context():
|
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
|
||||||
"""
|
"""Generate customized saved_tensors_hooks
|
||||||
Generate save tensors hooks context
|
|
||||||
|
Args:
|
||||||
|
offload_input (bool, optional): whether we need offload input, if offload_input=False,
|
||||||
|
we will use self.pack_hook_no_input instead. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: generated context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
|
if offload_input:
|
||||||
|
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
|
||||||
|
else:
|
||||||
|
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_save_on_cpu_context(pin_memory: bool = True) -> str:
    """Generate the header line of a ``save_on_cpu`` context.

    Args:
        pin_memory (bool, optional): value rendered into the ``pin_memory``
            keyword of the generated ``torch.autograd.graph.save_on_cpu``
            call. Defaults to True, which reproduces the string this
            function emitted when the value was hard-coded.

    Returns:
        str: a newline-terminated ``with`` statement header. The caller is
        responsible for emitting the indented body that follows it.
    """
    # bool() keeps the rendered text a valid Python literal even when a
    # truthy/falsy non-bool value is passed in.
    context = f"with torch.autograd.graph.save_on_cpu(pin_memory={bool(pin_memory)}):\n"
    return context
|
||||||
|
|
||||||
|
|
||||||
|
@ -118,6 +142,51 @@ def _find_ckpt_regions(nodes: List[Node]):
|
||||||
return ckpt_regions
|
return ckpt_regions
|
||||||
|
|
||||||
|
|
||||||
|
def _find_offload_regions(nodes: List[Node]):
    """Find the offload regions in a sequence of nodes.

    A node belongs to an offload region when its ``activation_offload``
    attribute is a tuple. In the pofo algorithm, annotation marks each offload
    region with a tuple of the form ``(idx, offload_input, offload_bar)``:
    ``idx`` is the index of the offload region, ``offload_input`` is a bool
    indicating whether the input of the region needs to be offloaded, and
    ``offload_bar`` is a bool indicating whether all the intermediate x_bars
    of the region need to be offloaded.

    Consecutive nodes carrying equal tuples form one region. Nodes whose
    ``activation_offload`` is missing or is not a tuple (e.g. a plain bool)
    are outside every region.

    Args:
        nodes (List[Node]): nodes in the order they are emitted.

    Returns:
        tuple: ``(offload_regions, offload_labels)``. ``offload_regions[i]`` is
        the inclusive ``(start, end)`` index pair of the i-th region and
        ``offload_labels[i]`` is the annotation tuple of that same region.
    """
    offload_regions = []
    offload_labels = []
    start = -1
    current_region = None
    last_idx = -1

    for idx, node in enumerate(nodes):
        last_idx = idx

        # only tuple annotations describe an offload region
        label = getattr(node, 'activation_offload', None)
        if not isinstance(label, tuple):
            label = None

        # still inside the same region, or still outside of any region
        if label == current_region:
            continue

        # the region that was open ends on the previous node
        if current_region is not None:
            offload_regions.append((start, idx - 1))

        # a new region starts on this node
        if label is not None:
            start = idx
            offload_labels.append(label)

        current_region = label

    # close a region that runs up to the last node; otherwise its label would
    # be recorded without a matching (start, end) entry
    if current_region is not None:
        offload_regions.append((start, last_idx))

    return offload_regions, offload_labels
|
||||||
|
|
||||||
|
|
||||||
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Generate the checkpoint function definition
|
Generate the checkpoint function definition
|
||||||
|
@ -322,8 +391,23 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
|
||||||
start_idx = [item[0] for item in ckpt_regions]
|
start_idx = [item[0] for item in ckpt_regions]
|
||||||
end_idx = [item[1] for item in ckpt_regions]
|
end_idx = [item[1] for item in ckpt_regions]
|
||||||
|
|
||||||
|
# find the offload regions
|
||||||
|
offload_regions, offload_labels = _find_offload_regions(nodes)
|
||||||
|
offload_starts = [item[0] for item in offload_regions]
|
||||||
|
offload_ends = [item[1] for item in offload_regions]
|
||||||
|
offload_inputs = []
|
||||||
|
offload_outputs = []
|
||||||
|
within_offload_region = False
|
||||||
|
|
||||||
node_list = list(nodes)
|
node_list = list(nodes)
|
||||||
|
|
||||||
|
# find the input and output var names for each offload region
|
||||||
|
for idx, (start, end) in enumerate(offload_regions):
|
||||||
|
offload_node_list = node_list[start:end + 1]
|
||||||
|
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
||||||
|
offload_inputs.append(inputs)
|
||||||
|
offload_outputs.append(outputs)
|
||||||
|
|
||||||
# this flag is to prevent repeated insert of save tensors
|
# this flag is to prevent repeated insert of save tensors
|
||||||
# hooks definition in ckpt_func
|
# hooks definition in ckpt_func
|
||||||
is_hook_inserted = False
|
is_hook_inserted = False
|
||||||
|
@ -343,19 +427,31 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
|
||||||
else:
|
else:
|
||||||
node = node_list[node_idx]
|
node = node_list[node_idx]
|
||||||
|
|
||||||
# if a node is outside of checkpoint region and want to offload
|
if node_idx in offload_starts:
|
||||||
# it's input activation, we will use torch.saved_tensors_hooks
|
offload_label = offload_labels[offload_starts.index(node_idx)]
|
||||||
# to complete the offload process.
|
_, offload_input, offload_bar = offload_label
|
||||||
if getattr(node, "activation_offload", False):
|
within_offload_region = True
|
||||||
|
|
||||||
|
# insert hook functions if needed
|
||||||
if not is_hook_inserted:
|
if not is_hook_inserted:
|
||||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
||||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
||||||
|
is_hook_inserted = True
|
||||||
|
|
||||||
for par in node.all_input_nodes:
|
if offload_input and offload_bar:
|
||||||
# annotate the input tensor for pack hook
|
body.append(_gen_save_on_cpu_context())
|
||||||
body.append(f"setattr({repr(par)}, 'offload', True)\n")
|
|
||||||
|
|
||||||
body.append(_gen_save_tensors_hooks_context())
|
elif offload_input:
|
||||||
|
for par in offload_inputs[offload_label[0]]:
|
||||||
|
body.append(f"setattr({par}, 'offload', True)\n")
|
||||||
|
body.append(_gen_save_tensors_hooks_context(offload_input=True))
|
||||||
|
|
||||||
|
else:
|
||||||
|
for par in offload_inputs[offload_label[0]]:
|
||||||
|
body.append(f"setattr({par}, 'offload', False)\n")
|
||||||
|
body.append(_gen_save_tensors_hooks_context(offload_input=False))
|
||||||
|
|
||||||
|
if within_offload_region:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
body[-1] = ' ' + body[-1]
|
body[-1] = ' ' + body[-1]
|
||||||
delete_unused_value_func(node, body)
|
delete_unused_value_func(node, body)
|
||||||
|
@ -363,6 +459,10 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
|
||||||
else:
|
else:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
delete_unused_value_func(node, body)
|
delete_unused_value_func(node, body)
|
||||||
|
|
||||||
|
if node_idx in offload_ends:
|
||||||
|
within_offload_region = False
|
||||||
|
|
||||||
node_idx += 1
|
node_idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -375,6 +475,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
output_vars = []
|
output_vars = []
|
||||||
within_ckpt_region = False
|
within_ckpt_region = False
|
||||||
|
|
||||||
|
# find the offload regions
|
||||||
|
offload_regions, offload_labels = _find_offload_regions(nodes)
|
||||||
|
offload_starts = [item[0] for item in offload_regions]
|
||||||
|
offload_ends = [item[1] for item in offload_regions]
|
||||||
|
offload_inputs = []
|
||||||
|
offload_outputs = []
|
||||||
|
within_offload_region = False
|
||||||
|
|
||||||
node_list = list(nodes)
|
node_list = list(nodes)
|
||||||
|
|
||||||
# use this variable to avoid inserting hook functions
|
# use this variable to avoid inserting hook functions
|
||||||
|
@ -388,6 +496,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
input_vars.append(inputs)
|
input_vars.append(inputs)
|
||||||
output_vars.append(outputs)
|
output_vars.append(outputs)
|
||||||
|
|
||||||
|
# find the input and output var names for each offload region
|
||||||
|
for idx, (start, end) in enumerate(offload_regions):
|
||||||
|
offload_node_list = node_list[start:end + 1]
|
||||||
|
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
||||||
|
offload_inputs.append(inputs)
|
||||||
|
offload_outputs.append(outputs)
|
||||||
|
|
||||||
# append code text to body
|
# append code text to body
|
||||||
for idx, node in enumerate(node_list):
|
for idx, node in enumerate(node_list):
|
||||||
# if this is the first node of the ckpt region
|
# if this is the first node of the ckpt region
|
||||||
|
@ -398,6 +513,30 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||||
within_ckpt_region = True
|
within_ckpt_region = True
|
||||||
|
|
||||||
|
if idx in offload_starts:
|
||||||
|
offload_label = offload_labels[offload_starts.index(idx)]
|
||||||
|
_, offload_input, offload_bar = offload_label
|
||||||
|
within_offload_region = True
|
||||||
|
|
||||||
|
# insert hook functions if needed
|
||||||
|
if not is_hook_inserted:
|
||||||
|
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
||||||
|
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
||||||
|
is_hook_inserted = True
|
||||||
|
|
||||||
|
if offload_input and offload_bar:
|
||||||
|
body.append(_gen_save_on_cpu_context())
|
||||||
|
|
||||||
|
elif offload_input:
|
||||||
|
for par in offload_inputs[offload_label[0]]:
|
||||||
|
body.append(f"setattr({par}, 'offload', True)\n")
|
||||||
|
body.append(_gen_save_tensors_hooks_context(offload_input=True))
|
||||||
|
|
||||||
|
else:
|
||||||
|
for par in offload_inputs[offload_label[0]]:
|
||||||
|
body.append(f"setattr({par}, 'offload', False)\n")
|
||||||
|
body.append(_gen_save_tensors_hooks_context(offload_input=False))
|
||||||
|
|
||||||
# NOTE: emit_node does not emit a string with newline. It depends
|
# NOTE: emit_node does not emit a string with newline. It depends
|
||||||
# on delete_unused_values to append one
|
# on delete_unused_values to append one
|
||||||
# NOTE: currently we separate body and ckpt_func definition
|
# NOTE: currently we separate body and ckpt_func definition
|
||||||
|
@ -405,27 +544,15 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
emit_node_func(node, ckpt_func)
|
emit_node_func(node, ckpt_func)
|
||||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||||
delete_unused_value_func(node, ckpt_func)
|
delete_unused_value_func(node, ckpt_func)
|
||||||
|
|
||||||
|
elif within_offload_region:
|
||||||
|
emit_node_func(node, body)
|
||||||
|
body[-1] = ' ' + body[-1]
|
||||||
|
delete_unused_value_func(node, body)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# if a node is outside of checkpoint region wants to offload
|
emit_node_func(node, body)
|
||||||
# it's input activation, we will use torch.saved_tensors_hooks
|
delete_unused_value_func(node, body)
|
||||||
# to complete the offload process.
|
|
||||||
if getattr(node, "activation_offload", False):
|
|
||||||
if not is_hook_inserted:
|
|
||||||
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
|
|
||||||
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
|
|
||||||
|
|
||||||
for par in node.all_input_nodes:
|
|
||||||
# annotate the input tensor for pack hook
|
|
||||||
body.append(f"setattr({repr(par)}, 'offload', True)\n")
|
|
||||||
|
|
||||||
body.append(_gen_save_tensors_hooks_context())
|
|
||||||
emit_node_func(node, body)
|
|
||||||
body[-1] = ' ' + body[-1]
|
|
||||||
delete_unused_value_func(node, body)
|
|
||||||
|
|
||||||
else:
|
|
||||||
emit_node_func(node, body)
|
|
||||||
delete_unused_value_func(node, body)
|
|
||||||
|
|
||||||
if idx in end_idx:
|
if idx in end_idx:
|
||||||
# if this is the last node of the ckpt region
|
# if this is the last node of the ckpt region
|
||||||
|
@ -470,6 +597,9 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
body.append(usage)
|
body.append(usage)
|
||||||
within_ckpt_region = False
|
within_ckpt_region = False
|
||||||
|
|
||||||
|
if idx in offload_ends:
|
||||||
|
within_offload_region = False
|
||||||
|
|
||||||
|
|
||||||
if CODEGEN_AVAILABLE:
|
if CODEGEN_AVAILABLE:
|
||||||
|
|
||||||
|
|
|
@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.linear0 = torch.nn.Linear(4, 4)
|
||||||
self.linear1 = torch.nn.Linear(4, 4)
|
self.linear1 = torch.nn.Linear(4, 4)
|
||||||
self.linear2 = torch.nn.Linear(4, 4)
|
self.linear2 = torch.nn.Linear(4, 4)
|
||||||
self.linear3 = torch.nn.Linear(4, 4)
|
self.linear3 = torch.nn.Linear(4, 4)
|
||||||
self.linear4 = torch.nn.Linear(4, 4)
|
self.linear4 = torch.nn.Linear(4, 4)
|
||||||
self.linear5 = torch.nn.Linear(4, 4)
|
self.linear5 = torch.nn.Linear(4, 4)
|
||||||
|
self.linear6 = torch.nn.Linear(4, 4)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
x = self.linear0(x)
|
||||||
x = self.linear1(x)
|
x = self.linear1(x)
|
||||||
x = self.linear2(x)
|
x = self.linear2(x)
|
||||||
x = self.linear3(x)
|
x = self.linear3(x)
|
||||||
x = self.linear4(x)
|
x = self.linear4(x)
|
||||||
x = self.linear5(x)
|
x = self.linear5(x)
|
||||||
|
x = self.linear6(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,25 +82,32 @@ def _run_offload_codegen(rank):
|
||||||
# also annotate the activation_checkpoint so we could test both types
|
# also annotate the activation_checkpoint so we could test both types
|
||||||
# of input offload
|
# of input offload
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
|
if node.name == "linear0":
|
||||||
|
setattr(node, "activation_offload", (0, True, False))
|
||||||
|
if node.name == "linear1":
|
||||||
|
setattr(node, "activation_offload", (0, True, False))
|
||||||
if node.name == "linear2":
|
if node.name == "linear2":
|
||||||
setattr(node, "activation_offload", True)
|
setattr(node, "activation_offload", (1, True, True))
|
||||||
if node.name == "linear3":
|
|
||||||
setattr(node, "activation_offload", True)
|
|
||||||
setattr(node, "activation_checkpoint", [0])
|
|
||||||
if node.name == "linear4":
|
if node.name == "linear4":
|
||||||
|
setattr(node, "activation_offload", (2, False, True))
|
||||||
|
if node.name == "linear5":
|
||||||
setattr(node, "activation_checkpoint", [0])
|
setattr(node, "activation_checkpoint", [0])
|
||||||
|
setattr(node, "activation_offload", True)
|
||||||
|
|
||||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
print(gm)
|
|
||||||
|
|
||||||
# assert we have all the components
|
# assert we have all the components
|
||||||
code = graph.python_code("self").src
|
code = graph.python_code("self").src
|
||||||
assert "def pack_hook(self, x):" in code and \
|
assert "def pack_hook_input(self, x):" in code and \
|
||||||
"def unpack_hook(self, packed):" in code and \
|
"def unpack_hook(self, packed):" in code and \
|
||||||
"setattr(linear1, 'offload', True)" in code and \
|
"def pack_hook_no_input(self, x):" in code and \
|
||||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
|
"setattr(x, 'offload', True)" in code and \
|
||||||
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
|
"setattr(linear3, 'offload', False)" in code and \
|
||||||
|
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
|
||||||
|
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
|
||||||
|
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
|
||||||
|
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
|
||||||
|
|
||||||
_test_fwd_and_bwd(model, gm, data)
|
_test_fwd_and_bwd(model, gm, data)
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
|
@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):
|
||||||
# also annotate the activation_checkpoint so we could test both types
|
# also annotate the activation_checkpoint so we could test both types
|
||||||
# of input offload
|
# of input offload
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
|
if node.name == "linear0":
|
||||||
|
setattr(node, "activation_offload", (0, True, False))
|
||||||
|
if node.name == "linear1":
|
||||||
|
setattr(node, "activation_offload", (0, True, False))
|
||||||
if node.name == "linear2":
|
if node.name == "linear2":
|
||||||
setattr(node, "activation_offload", True)
|
setattr(node, "activation_offload", (1, True, True))
|
||||||
if node.name == "linear3":
|
|
||||||
setattr(node, "activation_offload", True)
|
|
||||||
setattr(node, "activation_checkpoint", [0])
|
|
||||||
if node.name == "linear4":
|
if node.name == "linear4":
|
||||||
|
setattr(node, "activation_offload", (2, False, True))
|
||||||
|
if node.name == "linear5":
|
||||||
setattr(node, "activation_checkpoint", [0])
|
setattr(node, "activation_checkpoint", [0])
|
||||||
|
setattr(node, "activation_offload", True)
|
||||||
|
|
||||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
print(gm)
|
|
||||||
|
|
||||||
# assert we have all the components
|
# assert we have all the components
|
||||||
code = graph.python_code("self").src
|
code = graph.python_code("self").src
|
||||||
assert "def pack_hook(self, x):" in code and \
|
assert "def pack_hook_input(self, x):" in code and \
|
||||||
"def unpack_hook(self, packed):" in code and \
|
"def unpack_hook(self, packed):" in code and \
|
||||||
"setattr(linear1, 'offload', True)" in code and \
|
"def pack_hook_no_input(self, x):" in code and \
|
||||||
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
|
"setattr(x, 'offload', True)" in code and \
|
||||||
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
|
"setattr(linear3, 'offload', False)" in code and \
|
||||||
|
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
|
||||||
|
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
|
||||||
|
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
|
||||||
|
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
|
||||||
|
|
||||||
_test_fwd_and_bwd(model, gm, data)
|
_test_fwd_and_bwd(model, gm, data)
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
|
|
Loading…
Reference in New Issue