diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 1415e2f9d..4b99f4154 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -8,11 +8,12 @@ import enum
 import inspect
 import functools
 import operator
+from contextlib import contextmanager
 from colossalai.fx.tracer.meta_patch import meta_patched_module
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.fx import Tracer
+from torch.fx import Tracer, Node
 from torch.fx.graph import Graph
 from torch.fx.proxy import Proxy, ParameterProxy
 from ..proxy import ColoProxy
@@ -55,11 +56,17 @@ class ColoTracer(Tracer):
         graph = tracer.trace(model,
                              concrete_args={'y': torch.rand(4, 10)},
                              meta_args={'x': torch.rand(4, 10, device='meta')})
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.tracer_type = TracerType.META
         self.proxy_cls = ColoProxy
+        # whether the tracer will record the usage of torch.utils.checkpoint
+        self.trace_act_ckpt = trace_act_ckpt
+        # whether the current tracing occurs within an activation checkpoint function
+        self.inside_torch_checkpoint_func = False
+        self.act_ckpt_region_count = 0
+
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
@@ -297,7 +304,10 @@ class ColoTracer(Tracer):
         self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
 
         try:
-            self.graph = super().trace(root, concrete_args=concrete_args)
+            # to track the usage of torch.utils.checkpoint
+            with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
+                self.graph = super().trace(root, concrete_args=concrete_args)
+
         finally:
             # recover the patched methods
             for name, (_, orig) in self.patched_torch_tensor_methods.items():
@@ -338,6 +348,43 @@
 
         return self.graph
 
+    @contextmanager
+    def trace_activation_checkpoint(self, enabled: bool):
+        if enabled:
+            orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
+
+            class PatchedCheckpointFunction(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, run_function, preserve_rng_state, *args):
+                    # signal that the current tracing occurs within an activation checkpoint region
+                    self.inside_torch_checkpoint_func = True
+                    out = run_function(*args)
+                    self.inside_torch_checkpoint_func = False
+                    self.act_ckpt_region_count += 1
+                    return out
+
+                @staticmethod
+                def backward(ctx: Any, *grad_outputs: Any) -> Any:
+                    raise NotImplementedError(
+                        "We do not implement the backward pass as we only trace the forward pass.")
+
+            # override the checkpoint function
+            torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
+        yield
+
+        if enabled:
+            # recover the checkpoint function upon exit
+            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
+
+    def create_node(self, *args, **kwargs) -> Node:
+        node = super().create_node(*args, **kwargs)
+
+        if self.inside_torch_checkpoint_func:
+            # annotate the node with its activation checkpoint region index
+            setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
+        return node
+
 
 def wrap_tensor_constructor_method(target):
@@ -367,7 +414,7 @@ def wrap_tensor_constructor_method(target):
         colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
         if not isinstance(colo_proxy, ColoProxy):
             meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
-            colo_proxy = ColoProxy(fx_proxy.node)
+            colo_proxy = ColoProxy(proxy.node)
             colo_proxy.meta_data = meta_out
             return colo_proxy
         else:
diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py
new file mode 100644
index 000000000..3fd39b393
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+from torch.utils.checkpoint import checkpoint
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(4, 4)
+        self.linear2 = torch.nn.Linear(4, 4)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+# Simple module for demonstration
+class MyModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.mlp_1 = MLP()
+        self.mlp_2 = MLP()
+        self.output = torch.nn.Linear(4, 4)
+
+    def forward(self, x):
+        x = checkpoint(self.mlp_1, x)
+        x = checkpoint(self.mlp_2, x)
+        x = self.output(x)
+        return x
+
+
+def test_activation_checkpoint_annotation():
+    module = MyModule()
+
+    # test tracing with activation checkpoint
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(module)
+    gm = GraphModule(module, graph)
+
+    for node in gm.graph.nodes:
+        if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
+            assert getattr(node, 'activation_checkpoint', -1) == 0
+
+    for node in gm.graph.nodes:
+        if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
+            assert getattr(node, 'activation_checkpoint', -1) == 1
+
+    tracer = ColoTracer(trace_act_ckpt=False)
+    graph = tracer.trace(module)
+    gm = GraphModule(module, graph)
+
+    for node in gm.graph.nodes:
+        assert not hasattr(node, 'activation_checkpoint')
+
+
+if __name__ == '__main__':
+    test_activation_checkpoint_annotation()
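Not part of the patch: a minimal sketch, assuming the annotation behaves as exercised by the test above, of how a downstream pass might consume the per-node 'activation_checkpoint' index that ColoTracer attaches, e.g. to group node names by checkpoint region. The helper group_nodes_by_ckpt_region and the Block/Net modules are hypothetical illustrations, not ColossalAI APIs.

# A hypothetical consumer of the 'activation_checkpoint' node annotation.
import torch
from torch.fx import GraphModule
from torch.utils.checkpoint import checkpoint
from colossalai.fx import ColoTracer


class Block(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))


class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.blk1 = Block()
        self.blk2 = Block()

    def forward(self, x):
        x = checkpoint(self.blk1, x)
        x = checkpoint(self.blk2, x)
        return x


def group_nodes_by_ckpt_region(gm: GraphModule):
    """Collect node names per activation checkpoint region index."""
    regions = {}
    for node in gm.graph.nodes:
        # nodes traced outside torch.utils.checkpoint carry no annotation
        region = getattr(node, 'activation_checkpoint', None)
        if region is not None:
            regions.setdefault(region, []).append(node.name)
    return regions


if __name__ == '__main__':
    tracer = ColoTracer(trace_act_ckpt=True)
    net = Net()
    gm = GraphModule(net, tracer.trace(net))
    # expected to resemble {0: ['blk1_fc1', 'blk1_fc2'], 1: ['blk2_fc1', 'blk2_fc2']}
    print(group_nodes_by_ckpt_region(gm))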