[fx] added activation checkpointing annotation (#1349)

* [fx] added activation checkpointing annotation

* polish code

* polish code
Frank Lee 2 years ago committed by GitHub
parent 051592c64e
commit 05fae1fd56
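In short, the change adds a `trace_act_ckpt` flag to `ColoTracer`: while tracing, `torch.utils.checkpoint.CheckpointFunction` is temporarily patched so that every FX node created inside a `checkpoint(...)` call is stamped with an `activation_checkpoint` region index. A minimal usage sketch, mirroring the new test at the bottom of this diff (the `Net` module and shapes here are illustrative, not part of the PR):

import torch
from torch.utils.checkpoint import checkpoint
from colossalai.fx import ColoTracer


class Net(torch.nn.Module):
    # illustrative module: one checkpointed block followed by a plain linear head

    def __init__(self):
        super().__init__()
        self.block = torch.nn.Linear(4, 4)
        self.head = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = checkpoint(self.block, x)
        return self.head(x)


tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(Net())
for node in graph.nodes:
    # nodes created inside checkpoint(...) carry the region index (0, 1, ...); others carry nothing
    print(node.name, getattr(node, 'activation_checkpoint', None))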

@@ -8,11 +8,12 @@ import enum
 import inspect
 import functools
 import operator
+from contextlib import contextmanager
 from colossalai.fx.tracer.meta_patch import meta_patched_module
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.fx import Tracer
+from torch.fx import Tracer, Node
 from torch.fx.graph import Graph
 from torch.fx.proxy import Proxy, ParameterProxy
 from ..proxy import ColoProxy
@@ -55,11 +56,17 @@ class ColoTracer(Tracer):
         graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.tracer_type = TracerType.META
         self.proxy_cls = ColoProxy

+        # whether the tracer will record the usage of torch.utils.checkpoint
+        self.trace_act_ckpt = trace_act_ckpt
+
+        # whether the current tracing occurs within the activation checkpoint functions
+        self.inside_torch_checkpoint_func = False
+        self.act_ckpt_region_count = 0
+
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
@@ -297,7 +304,10 @@ class ColoTracer(Tracer):
         self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]

         try:
-            self.graph = super().trace(root, concrete_args=concrete_args)
+            # to track the usage of torch.utils.checkpoint
+            with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
+                self.graph = super().trace(root, concrete_args=concrete_args)
         finally:
             # recover the patched methods
             for name, (_, orig) in self.patched_torch_tensor_methods.items():
@@ -338,6 +348,43 @@ class ColoTracer(Tracer):
         return self.graph

+    @contextmanager
+    def trace_activation_checkpoint(self, enabled: bool):
+        if enabled:
+            orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
+
+            class PatchedCheckpointFunction(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, run_function, preserve_rng_state, *args):
+                    # signal that the current tracing occurs within an activation checkpoint region
+                    self.inside_torch_checkpoint_func = True
+                    out = run_function(*args)
+                    self.inside_torch_checkpoint_func = False
+                    self.act_ckpt_region_count += 1
+                    return out
+
+                @staticmethod
+                def backward(ctx: Any, *grad_outputs: Any) -> Any:
+                    raise NotImplementedError(
+                        "We do not implement the backward pass as we only trace the forward pass.")
+
+            # override the checkpoint function
+            torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
+
+        yield
+
+        if enabled:
+            # recover the checkpoint function upon exit
+            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
+
+    def create_node(self, *args, **kwargs) -> Node:
+        node = super().create_node(*args, **kwargs)
+
+        if self.inside_torch_checkpoint_func:
+            # annotate the activation checkpoint module
+            setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
+
+        return node
+

 def wrap_tensor_constructor_method(target):
@@ -367,7 +414,7 @@ def wrap_tensor_constructor_method(target):
             colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
             if not isinstance(colo_proxy, ColoProxy):
                 meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
-                colo_proxy = ColoProxy(fx_proxy.node)
+                colo_proxy = ColoProxy(proxy.node)
                 colo_proxy.meta_data = meta_out
             return colo_proxy
         else:

@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from torch.utils.checkpoint import checkpoint


class MLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


# Simple module for demonstration
class MyModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.mlp_1 = MLP()
        self.mlp_2 = MLP()
        self.output = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = checkpoint(self.mlp_1, x)
        x = checkpoint(self.mlp_2, x)
        x = self.output(x)
        return x


def test_activation_checkpoint_annotation():
    module = MyModule()

    # test tracing with activation checkpoint
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(module)
    gm = GraphModule(module, graph)

    for node in gm.graph.nodes:
        if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
            assert getattr(node, 'activation_checkpoint', -1) == 0

    for node in gm.graph.nodes:
        if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
            assert getattr(node, 'activation_checkpoint', -1) == 1

    tracer = ColoTracer(trace_act_ckpt=False)
    graph = tracer.trace(module)
    gm = GraphModule(module, graph)

    for node in gm.graph.nodes:
        assert not hasattr(node, 'activation_checkpoint')


if __name__ == '__main__':
    test_activation_checkpoint_annotation()
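
Downstream passes can read the annotation straight off the traced graph. A hypothetical helper (not part of this PR) that buckets node names by checkpoint region, assuming only the `activation_checkpoint` attribute set by the tracer above:

from collections import defaultdict


def group_by_ckpt_region(graph):
    # hypothetical helper: collect node names per activation-checkpoint region index;
    # nodes traced outside any checkpointed call carry no attribute and are skipped
    regions = defaultdict(list)
    for node in graph.nodes:
        idx = getattr(node, 'activation_checkpoint', None)
        if idx is not None:
            regions[idx].append(node.name)
    return dict(regions)


# For MyModule above this would yield something like:
# {0: ['mlp_1_linear1', 'mlp_1_linear2'], 1: ['mlp_2_linear1', 'mlp_2_linear2']}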