mirror of https://github.com/hpcaitech/ColossalAI
[fx] Add nested checkpoint in activation checkpoint codegen (#1585)
* [fx] add nested activation_checkpoint codegen
* undo algorithms commits
* solver
* undo some commits
* [fx] torch11 add nested activation checkpoint codegen
* remove some imports
* [fx] add some comments in activation codegen
* [fx] codegen instance error fix

pull/1550/head
parent 1c9ec32734
commit f3687e4ee2
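This commit lets an fx node's `activation_checkpoint` attribute be a list with one entry per nesting level (where `None` means "not checkpointed at this level") instead of a single int. A minimal annotation sketch, using the same labels as the test added in this commit (the `graph` variable is assumed to come from `ColoTracer`):

# Sketch of the nested annotation scheme exercised by the test below.
# [0, 0, None] means: top-level region 0, second-level region 0,
# not wrapped at the third level.
for node in graph.nodes:
    if node.name == "linear1":
        node.activation_checkpoint = [0, 0, 0]       # three levels deep
    elif node.name == "linear2":
        node.activation_checkpoint = [0, 0, None]    # two levels deep
    elif node.name == "linear5":
        node.activation_checkpoint = 1               # old flat int scheme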

@@ -109,6 +109,209 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant
    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
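For instance, with label '0_0', activation_offload False, input x, and use_reentrant False, this template renders to the call the new tests assert on (the `linear1` output name here is illustrative):

linear1 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)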


def _end_of_ckpt(node: Node, check_idx: int) -> bool:
    """Check if the node could end the ckpt region

    Args:
        node (Node): torch.fx.Node
        check_idx (int): the index of the checkpoint level to inspect,
            used for nested checkpoints

    Returns:
        bool
    """
    if hasattr(node, "activation_checkpoint"):
        if isinstance(node.activation_checkpoint, list):
            # a None entry at this level means the node is not checkpointed
            # here, so the region at this level ends
            return node.activation_checkpoint[check_idx] is None
        else:
            return False
    else:
        return True
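A small illustration of the intended semantics (a sketch using stand-in node objects, not part of the commit):

from types import SimpleNamespace

n1 = SimpleNamespace(activation_checkpoint=[0, 0, None])
n2 = SimpleNamespace(activation_checkpoint=[0, 1, 1])
n3 = SimpleNamespace()                  # unannotated node

_end_of_ckpt(n1, 2)   # True: the level-2 entry is None, so a level-2 region ends here
_end_of_ckpt(n2, 2)   # False: still inside a level-2 region
_end_of_ckpt(n3, 0)   # True: unannotated nodes always end a region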


def _find_nested_ckpt_regions(nodes, check_idx=0):
    """
    Find the nested checkpoint regions given a list of consecutive nodes. The
    output is a list of tuples, each in the form (start_index, end_index).
    """
    ckpt_regions = []
    start = -1
    end = -1
    current_region = None

    for idx, node in enumerate(nodes):
        if hasattr(node, 'activation_checkpoint'):
            if isinstance(getattr(node, 'activation_checkpoint'), int):
                act_ckpt_label = node.activation_checkpoint
            else:
                act_ckpt_label = node.activation_checkpoint[check_idx]

            # the current region label is not set yet, meaning this is
            # the first node of an activation ckpt region
            if current_region is None:
                current_region = act_ckpt_label
                start = idx

            # if the activation checkpoint label has changed, we restart the
            # tracking, e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
            if act_ckpt_label != current_region:
                assert start != -1
                ckpt_regions.append((start, idx - 1))
                current_region = act_ckpt_label
                start = idx
                end = -1
        elif current_region is not None and _end_of_ckpt(node, check_idx):
            # handles the case below:
            # node ckpt states = [ckpt, ckpt, non-ckpt]
            end = idx - 1
            assert start != -1 and end != -1
            ckpt_regions.append((start, end))
            start = end = -1
            current_region = None
        else:
            pass

    if current_region is not None:
        end = len(nodes) - 1
        ckpt_regions.append((start, end))
    return ckpt_regions
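A worked example of the region search (a sketch with stand-in node objects; labels follow the scheme used in the test below):

from types import SimpleNamespace

nodes = [
    SimpleNamespace(),                              # placeholder, no ckpt
    SimpleNamespace(activation_checkpoint=[0, 0]),
    SimpleNamespace(activation_checkpoint=[0, 0]),
    SimpleNamespace(activation_checkpoint=[0, 1]),
    SimpleNamespace(),                              # no ckpt
]

_find_nested_ckpt_regions(nodes, check_idx=0)       # [(1, 3)]: one top-level region
_find_nested_ckpt_regions(nodes[1:4], check_idx=1)  # [(0, 1), (2, 2)]: two inner regions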


def emit_ckpt_func(body,
                   ckpt_func,
                   node_list: List[Node],
                   emit_node_func,
                   delete_unused_value_func,
                   level=0,
                   in_ckpt=False):
    """Emit a checkpoint function in a nested way

    Args:
        body: forward code; in recursive calls, this part will be the
            checkpoint functions' code
        ckpt_func: checkpoint functions code; in recursive calls, this part
            will be a buffer
        node_list (List[Node]): list of torch.fx.Node
        emit_node_func: function to emit a node
        delete_unused_value_func: function to delete unused values
        level (int, optional): checkpoint level. Defaults to 0.
        in_ckpt (bool, optional): indicates whether the function is in a
            recursive call. Defaults to False.
    """
    inputs, outputs = _find_input_and_output_nodes(node_list)

    # if the current checkpoint function uses an int as its label,
    # fall back to the old generation method
    if isinstance(node_list[0].activation_checkpoint, int):
        label = node_list[0].activation_checkpoint
        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
        ckpt_func.append(f'{ckpt_fn_def}\n')
        for node in node_list:
            emit_node_func(node, ckpt_func)
            ckpt_func[-1] = '    ' + ckpt_func[-1]
            delete_unused_value_func(node, ckpt_func)

        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
        activation_offload = getattr(node_list[0], "activation_offload", False)
        usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
        usage += "\n"
        body.append(usage)

    # use the nested ckpt function codegen
    else:
        # the label is built from every level visited so far, e.g. if you are
        # currently at level [0, 1, 1] the label will be '0_1_1'
        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
        ckpt_func.append(f'{ckpt_fn_def}\n')

        # if there are more levels to fetch
        if level + 1 < len(node_list[0].activation_checkpoint):
            ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
            start_idx = [item[0] for item in ckpt_regions]
            end_idx = [item[1] for item in ckpt_regions]

            # use ckpt_func_buffer to store nested checkpoint functions
            ckpt_func_buffer = []
            node_idx = 0
            while node_idx < len(node_list):
                if node_idx in start_idx:
                    ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
                    emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
                                   delete_unused_value_func, level + 1, True)
                    node_idx += len(ckpt_node_list)
                else:
                    node = node_list[node_idx]
                    emit_node_func(node, ckpt_func)
                    ckpt_func[-1] = '    ' + ckpt_func[-1]
                    delete_unused_value_func(node, ckpt_func)
                    node_idx += 1

            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
            ckpt_func += ckpt_func_buffer
            activation_offload = getattr(node_list[0], "activation_offload", False)
            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
            if in_ckpt:
                usage = '    ' + usage
            body.append(usage)

        # last level
        else:
            for node in node_list:
                emit_node_func(node, ckpt_func)
                ckpt_func[-1] = '    ' + ckpt_func[-1]
                delete_unused_value_func(node, ckpt_func)

            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
            activation_offload = getattr(node_list[0], "activation_offload", False)
            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
            if in_ckpt:
                usage = '    ' + usage
            body.append(usage)
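A quick sanity check of the label scheme described in the comment above (plain Python, no codegen state needed):

labels = [0, 1, 1]      # a node's nested annotation
level = 2
label = "_".join(str(idx) for idx in labels[:level + 1])
assert label == '0_1_1'    # referenced as self.checkpoint_0_1_1 in the generated call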


def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    """Emit code with nested activation checkpoints

    When we detect that some node.activation_checkpoint is a list, we use this
    function to emit the activation checkpoint code.

    Args:
        body: forward code
        ckpt_func: checkpoint functions code
        nodes: graph.nodes
        emit_node_func: function to emit a node
        delete_unused_value_func: function to remove unused values
    """
    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]

    node_list = list(nodes)

    node_idx = 0
    # stop once all the nodes have been processed
    while node_idx < len(node_list):
        # process ckpt_regions
        if node_idx in start_idx:
            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
            node_idx += len(ckpt_node_list)

        # process nodes in the forward function
        else:
            node = node_list[node_idx]
            emit_node_func(node, body)
            delete_unused_value_func(node, body)
            node_idx += 1
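To make the end result concrete: from the substrings asserted by the test added below, the code generated for its annotation looks roughly like this (an illustrative reconstruction, not verbatim codegen output):

def forward(self, x):
    # top-level regions: checkpoint_0 covers linear1-linear4, checkpoint_1 covers linear5
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)
    linear5 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)
    return self.linear6(linear5)

def checkpoint_0(self, x):
    # second-level regions inside region 0
    linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)
    return colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)

def checkpoint_0_0(self, x):
    # third-level regions inside region 0_0
    linear1 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)
    linear2 = self.linear2(linear1)    # not checkpointed at this level ([0, 0, None])
    return colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)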


def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    # find the activation checkpoint regions
    ckpt_regions = _find_ckpt_regions(nodes)

@@ -384,7 +587,10 @@ if CODEGEN_AVAILABLE:

         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body

@@ -582,7 +788,10 @@ else:

         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body

@@ -0,0 +1,153 @@
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except:
    # fall back for older pytorch versions
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False


class MyModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear3 = torch.nn.Linear(4, 4)
        self.linear4 = torch.nn.Linear(4, 4)
        self.linear5 = torch.nn.Linear(4, 4)
        self.linear6 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x))))))


def _run_act_ckpt_codegen(rank):
    # launch colossalai so that colossalai.utils.checkpoint can run correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build the model and run a forward pass
    model = MyModule()
    data1 = torch.rand(4, 4)

    # move the model and data to cuda
    model = model.to(device="cuda")
    data1 = data1.to(device="cuda")

    non_fx_out = model(data1)

    # trace the module and replace the codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate the nested checkpoints:
    # linear1-linear4 share top-level region 0; within it, linear1-linear3
    # form second-level region 0 and linear4 forms second-level region 1;
    # linear5 uses the flat int scheme and linear6 is not checkpointed
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that the expected checkpoint functions are generated
    code = graph.python_code('self').src
    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

    # run the recompiled module and verify the outputs are consistent
    fx_out = gm(data1)
    assert torch.equal(non_fx_out, fx_out)

    gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    mp.spawn(_run_act_ckpt_codegen, nprocs=1)


def _run_act_ckpt_python_code_torch11(rank):
    # launch colossalai so that colossalai.utils.checkpoint can run correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build the model and run a forward pass
    model = MyModule()
    data1 = torch.rand(4, 4)

    # move the model and data to cuda
    model = model.to(device="cuda")
    data1 = data1.to(device="cuda")

    non_fx_out = model(data1)

    # trace the module and replace the codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate the nested checkpoints (same scheme as above)
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that the expected checkpoint functions are generated
    code = graph.python_code('self').src
    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

    # run the recompiled module and verify the outputs are consistent
    fx_out = gm(data1)
    assert torch.equal(non_fx_out, fx_out)

    gpc.destroy()


@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_act_ckpt_python_code_torch11():
    mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)


if __name__ == '__main__':
    _run_act_ckpt_codegen(rank=0)