[fx] Add nested checkpoint in activation checkpoint codegen (#1585)

* [fx] add nested activation_checkpoint codegen

* undo algorithms commits

* solver

* undo some commits

* [fx] torch11 add nested activation checkpoint codegen

* remove some imports

* [fx] add some comments in activation codegen

* [fx] codegen instance error fix
Boyuan Yao 2022-09-12 20:00:48 +08:00 committed by GitHub
parent 1c9ec32734
commit f3687e4ee2
2 changed files with 364 additions and 2 deletions

@@ -109,6 +109,209 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant):
    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'

def _end_of_ckpt(node: Node, check_idx: int) -> bool:
    """Check if the node could end the ckpt region

    Args:
        node (Node): torch.fx.Node
        check_idx (int): the index of the checkpoint level for
            nested checkpoints

    Returns:
        bool
    """
    if hasattr(node, "activation_checkpoint"):
        if isinstance(node.activation_checkpoint, list):
            return node.activation_checkpoint[check_idx] is None
        else:
            return False
    else:
        return True
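For example, a node annotated with activation_checkpoint = [0, 0, None] ends a level-2 region (its level-2 entry is None) but not a level-0 or level-1 region; a node with no annotation at all ends the region at every level, and a node with an int annotation never ends one.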

def _find_nested_ckpt_regions(nodes, check_idx=0):
    """
    Find the nested checkpoint regions given a list of consecutive nodes. The output
    is a list of tuples, each in the form of (start_index, end_index).
    """
    ckpt_regions = []
    start = -1
    end = -1
    current_region = None

    for idx, node in enumerate(nodes):
        if hasattr(node, 'activation_checkpoint'):
            if isinstance(node.activation_checkpoint, int):
                act_ckpt_label = node.activation_checkpoint
            else:
                act_ckpt_label = node.activation_checkpoint[check_idx]

            # no region is being tracked yet, so this node
            # is the first node of a new activation ckpt region
            if current_region is None:
                current_region = act_ckpt_label
                start = idx

            # if the activation checkpoint label has changed
            # we restart the tracking
            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
            if act_ckpt_label != current_region:
                assert start != -1
                ckpt_regions.append((start, idx - 1))
                current_region = act_ckpt_label
                start = idx
                end = -1
        elif current_region is not None and _end_of_ckpt(node, check_idx):
            # used to check the case below
            # node ckpt states = [ckpt, ckpt, non-ckpt]
            end = idx - 1
            assert start != -1 and end != -1
            ckpt_regions.append((start, end))
            start = end = -1
            current_region = None

    if current_region is not None:
        end = len(nodes) - 1
        ckpt_regions.append((start, end))
    return ckpt_regions
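As a minimal illustration of the region finder (MockNode is a hypothetical stand-in for torch.fx.Node and is not part of this commit), annotations similar to those in the test file below split into regions as follows:

class MockNode:

    def __init__(self, ckpt=None):
        if ckpt is not None:
            self.activation_checkpoint = ckpt

# nodes 0-2 share top-level label 0, node 3 uses the old int-style label 1,
# node 4 is unannotated and therefore outside any region
nodes = [MockNode([0, 0]), MockNode([0, 0]), MockNode([0, 1]), MockNode(1), MockNode()]
assert _find_nested_ckpt_regions(nodes, check_idx=0) == [(0, 2), (3, 3)]
# one level down, the first region splits on the second label
assert _find_nested_ckpt_regions(nodes[:3], check_idx=1) == [(0, 1), (2, 2)]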

def emit_ckpt_func(body,
                   ckpt_func,
                   node_list: List[Node],
                   emit_node_func,
                   delete_unused_value_func,
                   level=0,
                   in_ckpt=False):
    """Emit checkpoint functions in a nested way

    Args:
        body: forward code; in recursive calls, this will be the code of
            the parent checkpoint function
        ckpt_func: checkpoint functions code; in recursive calls, this
            will be a buffer
        node_list (List[Node]): list of torch.fx.Node
        emit_node_func: function to emit a node
        delete_unused_value_func: function to delete unused values
        level (int, optional): checkpoint level. Defaults to 0.
        in_ckpt (bool, optional): indicates whether the function is called
            recursively. Defaults to False.
    """
    inputs, outputs = _find_input_and_output_nodes(node_list)

    # if the current checkpoint function uses an int as its label,
    # fall back to the old generation method
    if isinstance(node_list[0].activation_checkpoint, int):
        label = node_list[0].activation_checkpoint
        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
        ckpt_func.append(f'{ckpt_fn_def}\n')
        for node in node_list:
            emit_node_func(node, ckpt_func)
            ckpt_func[-1] = '    ' + ckpt_func[-1]
            delete_unused_value_func(node, ckpt_func)

        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
        activation_offload = getattr(node_list[0], "activation_offload", False)
        usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
        usage += "\n"
        body.append(usage)

    # otherwise, use the nested ckpt function codegen
    else:
        # the label is built from one index per level, e.g. if you are
        # currently at level [0, 1, 1], the label will be '0_1_1'
        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
        ckpt_func.append(f'{ckpt_fn_def}\n')

        # if there are deeper levels to fetch
        if level + 1 < len(node_list[0].activation_checkpoint):
            ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
            start_idx = [item[0] for item in ckpt_regions]
            end_idx = [item[1] for item in ckpt_regions]

            # use ckpt_func_buffer to store nested checkpoint functions
            ckpt_func_buffer = []

            node_idx = 0
            while node_idx < len(node_list):
                if node_idx in start_idx:
                    ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
                    emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
                                   delete_unused_value_func, level + 1, True)
                    node_idx += len(ckpt_node_list)
                else:
                    node = node_list[node_idx]
                    emit_node_func(node, ckpt_func)
                    ckpt_func[-1] = '    ' + ckpt_func[-1]
                    delete_unused_value_func(node, ckpt_func)
                    node_idx += 1

            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
            ckpt_func += ckpt_func_buffer
            activation_offload = getattr(node_list[0], "activation_offload", False)
            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
            if in_ckpt:
                usage = '    ' + usage
            body.append(usage)

        # last level
        else:
            for node in node_list:
                emit_node_func(node, ckpt_func)
                ckpt_func[-1] = '    ' + ckpt_func[-1]
                delete_unused_value_func(node, ckpt_func)

            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
            activation_offload = getattr(node_list[0], "activation_offload", False)
            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
            if in_ckpt:
                usage = '    ' + usage
            body.append(usage)
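Note the emission order in the nested branch: the outer call first writes the def line of checkpoint_{label} into ckpt_func, the recursive calls then append usage lines into that same definition while buffering the nested function definitions in ckpt_func_buffer, and the buffer is only spliced into ckpt_func after the parent's return line. As a result, every nested checkpoint function appears after its parent in the generated source.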

def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    """Emit code with nested activation checkpoints

    When we detect that some node.activation_checkpoint is a list, we use
    this function to emit the activation checkpoint code.

    Args:
        body: forward code
        ckpt_func: checkpoint functions code
        nodes: graph.nodes
        emit_node_func: function to emit a node
        delete_unused_value_func: function to remove unused values
    """
    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]

    node_list = list(nodes)
    node_idx = 0
    # loop until all nodes have been processed
    while node_idx < len(node_list):
        # process ckpt_regions
        if node_idx in start_idx:
            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
            node_idx += len(ckpt_node_list)

        # process nodes in the forward function
        else:
            node = node_list[node_idx]
            emit_node_func(node, body)
            delete_unused_value_func(node, body)
            node_idx += 1

def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    # find the activation checkpoint regions
    ckpt_regions = _find_ckpt_regions(nodes)
@@ -384,7 +587,10 @@ if CODEGEN_AVAILABLE:
         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
@@ -582,7 +788,10 @@ else:
         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body

@@ -0,0 +1,153 @@
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except:
    # fall back to older pytorch version
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False


class MyModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear3 = torch.nn.Linear(4, 4)
        self.linear4 = torch.nn.Linear(4, 4)
        self.linear5 = torch.nn.Linear(4, 4)
        self.linear6 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x))))))

def _run_act_ckpt_codegen(rank):
    # launch colossalai so that we can execute colossalai.utils.checkpoint correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and run forward
    model = MyModule()
    data1 = torch.rand(4, 4)

    # copy model to cuda
    model = model.to(device="cuda")
    data1 = data1.to(device="cuda")

    non_fx_out = model(data1)

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate the nested checkpoint levels
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that the nested checkpoint functions are generated and called as expected
    code = graph.python_code('self').src
    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

    # run forward and verify that the outputs are consistent
    fx_out = gm(data1)
    assert torch.equal(non_fx_out, fx_out)

    gpc.destroy()
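For reference, here is a sketch of the source these annotations are expected to produce, reconstructed from the assertions above; the real codegen's exact names, ordering, and formatting may differ slightly:

def forward(self, x):
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)
    linear5 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)
    linear6 = self.linear6(linear5)
    return linear6

def checkpoint_0(self, x):
    linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)
    return linear4

def checkpoint_0_0(self, x):
    linear1 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)
    linear2 = self.linear2(linear1)
    linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)
    return linear3

def checkpoint_0_0_0(self, x):
    linear1 = self.linear1(x)
    return linear1

def checkpoint_0_0_1(self, linear2):
    linear3 = self.linear3(linear2)
    return linear3

def checkpoint_0_1(self, linear3):
    linear4 = self.linear4(linear3)
    return linear4

def checkpoint_1(self, linear4):
    linear5 = self.linear5(linear4)
    return linear5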

@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    mp.spawn(_run_act_ckpt_codegen, nprocs=1)

def _run_act_ckpt_python_code_torch11(rank):
    # launch colossalai so that we can execute colossalai.utils.checkpoint correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and run forward
    model = MyModule()
    data1 = torch.rand(4, 4)

    # copy model to cuda
    model = model.to(device="cuda")
    data1 = data1.to(device="cuda")

    non_fx_out = model(data1)

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate the nested checkpoint levels
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that the nested checkpoint functions are generated and called as expected
    code = graph.python_code('self').src
    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

    # run forward and verify that the outputs are consistent
    fx_out = gm(data1)
    assert torch.equal(non_fx_out, fx_out)

    gpc.destroy()

@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_act_ckpt_python_code_torch11():
    mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)


if __name__ == '__main__':
    _run_act_ckpt_codegen(rank=0)