ColossalAI/colossalai/_analyzer/fx/tracer/symbolic_trace.py

158 lines
5.7 KiB
Python
Raw Normal View History

from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Tracer
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
try:
from ..codegen import ActivationCheckpointCodeGen
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
from ..graph_module import ColoGraphModule
from .tracer import ColoTracer
def _default_device():
return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module):
try:
return next(module.parameters()).device
except:
return _default_device()
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
    trace_act_ckpt: bool = False,
    bias_addition_split: bool = False,
) -> ColoGraphModule:
    """
    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
    attached to the ``Node``s.

    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).

    This tracer is able to trace basic control flow and for loops.

    It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
    (See ./bias_addition.py for more details).

    When ``meta_args`` is non-empty, ``ColoTracer`` is used and every tensor in ``meta_args`` is wrapped
    into a ``MetaTensor`` on the default device. If ``root`` is a ``torch.nn.Module`` it is moved to that
    device for the duration of tracing and moved back afterwards, even if tracing raises. When ``meta_args``
    is empty or ``None``, the plain ``torch.fx.Tracer`` is used and ``trace_act_ckpt`` /
    ``bias_addition_split`` have no effect.

    Examples:
    1. Tracing a ``torch.nn.Module`` with control flow.

    .. code-block:: python

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                if x.size(0) > 1:
                    x = x.sum(dim=0)
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     linear_1 = self.linear(x)
        #     return linear_1

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     sum = x.sum(dim=0); x = None
        #     linear = self.linear(sum); sum = None
        #     return linear

    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.

    .. code-block:: python

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                def custom_forward(x):
                    return self.linear(x)
                return torch.utils.checkpoint.checkpoint(custom_forward, x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)

        # traced code like:
        # def checkpoint_0(self, x):
        #     linear = self.linear(x); x = None
        #     return linear
        #
        # def forward(self, x):
        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
        #     return linear

    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.

    .. code-block:: python

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2, bias=True)

            def forward(self, x):
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)

        # traced code like:
        # def forward(self, x):
        #     linear_bias = self.linear.bias
        #     linear_weight = self.linear.weight
        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
        #     add = linear + linear_bias; linear = linear_bias = None
        #     return add

    Args:
        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
            Defaults to None.
        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
            for tracing control flow. Defaults to None.
        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
            The activation checkpoint ``CodeGen`` is only attached when it could be imported
            (``SUPPORT_ACTIVATION``). Defaults to False.
        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.

    Returns:
        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.

    Remarks:
        This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
        any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
        repo. We welcome any feedback and contributions to enhance the extensibility of
        Colossal-AI.
    """
    is_module = isinstance(root, torch.nn.Module)

    if meta_args:
        device = _default_device()
        # Only modules can be relocated; a plain callable has no ``.to()``.
        orig_device = _current_device(root) if is_module else None

        def _to_meta(elem):
            # Leave non-tensor leaves (ints, strings, ...) untouched.
            return MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem

        tracer = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split)
        try:
            graph = tracer.trace(
                root.to(device) if is_module else root,
                concrete_args=concrete_args,
                meta_args=tree_map(_to_meta, meta_args),
            )
        finally:
            # Always hand the module back on the device the caller had it on,
            # even when tracing fails part-way through.
            if is_module:
                root.to(orig_device)

        if trace_act_ckpt and SUPPORT_ACTIVATION:
            graph.set_codegen(ActivationCheckpointCodeGen())
    else:
        graph = Tracer().trace(root, concrete_args=concrete_args)

    name = root.__class__.__name__ if is_module else root.__name__
    return ColoGraphModule(root, graph, name)