mirror of https://github.com/hpcaitech/ColossalAI
[fx] Use colossalai checkpoint and add offload recognition in codegen (#1439)
* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen
* Modification of test and add TODO in codegen
* [fx] Modification of colossal ckpt usage
* [fx] add gpc.destroy() to test_codegen
parent e9460b45c8
commit 5774fe0270
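For reference, a minimal sketch (not part of the commit) of the call text the codegen emits before and after this change, assuming a checkpoint region labelled 0 with a single input x and output linear_1:

# Illustrative sketch only: names and values below are invented for the example.
label = 0
activation_offload = True
inputs = "x"
outputs = "linear_1"

before = f"{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})"
after = f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})"

print(before)
print(after)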
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"
 
 
-def _gen_ckpt_usage(label, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
 
 
 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
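The new call text assumes the colossalai checkpoint wrapper takes the offload flag as its second positional argument, between the wrapped function and its inputs. A hypothetical stand-in (not the real colossalai API) that mirrors this calling convention:

# Hypothetical stand-in: a real implementation would move the saved inputs to CPU
# when activation_offload is True and bring them back for recomputation in backward;
# here we only forward the call to keep the sketch self-contained.
def fake_checkpoint(fn, activation_offload, *args):
    return fn(*args)

def checkpoint_0(x):
    return x * 2

print(fake_checkpoint(checkpoint_0, True, 3))  # 6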
@@ -155,8 +155,15 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
             return_statement = f'    {return_statement}\n'
             body.append(return_statement)
 
+            # we need to check if the checkpoint need to offload the input
+            start_node_idx = start_idx[label]
+            if hasattr(node_list[start_node_idx], 'activation_offload'):
+                activation_offload = node_list[start_node_idx].activation_offload
+            else:
+                activation_offload = False
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
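The hasattr/else branch above defaults to no offloading when the region's first node carries no annotation; it is equivalent to a getattr with a False default. A small sketch with a plain object standing in for an fx.Node (hypothetical, for illustration only):

# Hypothetical stand-in for node_list[start_idx[label]].
class FakeNode:
    pass

start_node = FakeNode()
# setattr(start_node, 'activation_offload', True)  # uncomment to request offloading

activation_offload = getattr(start_node, 'activation_offload', False)
print(activation_offload)  # False unless the annotation was set on the region's first node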
@@ -368,7 +375,11 @@ if codegen_available:
             for name, value in self.additional_globals():
                 add_global(name, value)
 
+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in forward function
+            # TODO: Remove inline import
             prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+            prologue = prologue + "\n    import colossalai"
 
             code = ''.join(body)
             code = '\n'.join('    ' + line for line in code.split('\n'))
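A rough illustration (not from the commit) of what the prologue tweak does: an indented import colossalai line is appended to the generated forward header so the emitted checkpoint calls can resolve colossalai.utils.activation_checkpoint.checkpoint at run time, which is also why the TODO marks the inline import as temporary.

# Only the string manipulation is shown; the header below is invented for the example.
prologue = "def forward(self, x):"
prologue = prologue + "\n    import colossalai"
print(prologue)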
@@ -566,9 +577,14 @@ else:
         orig_args.insert(0, 'self')
         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))
+
+        # as we need colossalai.utils.checkpoint, we need to import colossalai
+        # in forward function
+        # TODO: Remove inline import
         fn_code = f"""
 {wrap_stmts}
 
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
+    import colossalai
 {code}"""
         return PythonCode(fn_code, globals_)
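A hedged sketch of how the torch<1.12 path stitches the forward source together; wrap_stmts, code, orig_args and the return annotation are stand-ins for what the real codegen computes, and only the assembled string is printed:

# Stand-in values; the real codegen fills these from the traced graph.
wrap_stmts = ""
code = "    return x + 1"  # body lines arrive pre-indented by the caller
orig_args = ['self', 'x']
maybe_return_annotation = ['']

fn_code = f"""
{wrap_stmts}

def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
    import colossalai
{code}"""
print(fn_code)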
@@ -1,8 +1,12 @@
 from operator import mod
 import torch
 import pytest
 from torch.utils.checkpoint import checkpoint
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -40,9 +44,17 @@ class MyModule(torch.nn.Module):
 
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_act_ckpt_codegen():
+    # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
     # build model and run forward
     model = MyModule()
     data = torch.rand(4, 4)
+
+    # copy model to cuda
+    model = model.to(device="cuda")
+    data = data.to(device="cuda")
+
     non_fx_out = model(data)
 
     # trace the module and replace codegen
@@ -52,14 +64,22 @@ def test_act_ckpt_codegen():
     graph.set_codegen(codegen)
 
     # check ops are annotated with ckpt
+    # also annotate the selected node for offloading
     ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    offload_starts = ['mlp2_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
 
-    # assert checkpoint function will be generated
+        # annotate the selected node for offload
+        if node.name in offload_starts:
+            setattr(node, 'activation_offload', True)
+
+    # assert checkpoint function will be generated and
+    # the offload option is correct
     code = graph.python_code('self').src
     assert 'checkpoint_0' in code and 'checkpoint_1' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
 
     # recompile and verify the outputs are consistent
     gm = GraphModule(model, graph)
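The test relies on fx nodes accepting arbitrary attributes: the checkpoint and offload markers are plain attributes set on traced nodes, which the checkpoint-aware codegen later inspects. A torch-only sketch (module and names invented, colossalai not required) of the same annotation pattern:

import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

gm = symbolic_trace(Tiny())
for node in gm.graph.nodes:
    if node.name == 'linear':
        setattr(node, 'activation_checkpoint', 0)  # checkpoint region label
        setattr(node, 'activation_offload', True)  # request input offloading for the region

print([(n.name, getattr(n, 'activation_offload', False)) for n in gm.graph.nodes])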
@@ -67,12 +87,22 @@ def test_act_ckpt_codegen():
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
 
+    gpc.destroy()
+
 
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_act_ckpt_python_code_torch11():
+    # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
     # build model and run forward
     model = MyModule()
     data = torch.rand(4, 4)
+
+    # copy model to cuda
+    model = model.to(device="cuda")
+    data = data.to(device="cuda")
+
     non_fx_out = model(data)
 
     # trace the module and replace codegen
@@ -84,13 +114,20 @@ def test_act_ckpt_python_code_torch11():
 
     # check ops are annotated with ckpt
     ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    offload_starts = ['mlp2_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
 
-    # assert checkpoint function will be generated
+        # annotate the selected node for offload
+        if node.name in offload_starts:
+            setattr(node, 'activation_offload', True)
+
+    # assert checkpoint function will be generated and
+    # the offload option is correct
     code = graph.python_code('self').src
     assert 'checkpoint_0' in code and 'checkpoint_1' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
 
     # recompile and verify the outputs are consistent
     gm = GraphModule(model, graph)
@@ -98,7 +135,10 @@ def test_act_ckpt_python_code_torch11():
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
 
+    gpc.destroy()
+
 
 if __name__ == '__main__':
     test_act_ckpt_codegen()
     test_act_ckpt_python_code_torch11()