You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/fx/tracer/_meta_trace.py

140 lines
4.7 KiB

import torch
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map
def normalize_tuple(x):
    """Return ``x`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
def is_autogradable(x):
    """Tell whether ``x`` is a ``torch.Tensor`` with a floating-point dtype."""
    # Anything that is not a tensor can never take part in autograd.
    if not isinstance(x, torch.Tensor):
        return False
    return x.is_floating_point()
def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
    """Trace forward and backward graph with MetaTensor

    Every tensor found in ``args`` / ``kwargs`` is wrapped in a ``MetaProxy``
    whose payload lives on the ``meta`` device. ``module`` is then called on
    the wrapped inputs, and ``torch.autograd.backward`` is invoked on each
    floating-point output that requires grad. Each aten call that reaches
    ``MetaProxy.__torch_dispatch__`` is recorded as a ``call_function`` node.

    Args:
        module (torch.nn.Module): The target module for tracing.
        fake_device (optional): Device reported by the wrapped inputs. When
            ``None``, each wrapper reports the device of the tensor it wraps.
        *args: Positional inputs forwarded to ``module``.
        **kwargs: Keyword inputs forwarded to ``module``.

    Returns:
        graph (torch.fx.Graph): The computation graph.

    Usage:
        >>> import torchvision.models as tm
        >>> model = tm.alexnet()
        >>> # ``fake_device`` is the second positional parameter, so it has to
        >>> # be supplied before the module inputs.
        >>> graph = meta_trace(model, torch.device("cpu"), torch.rand(1000, 3, 224, 224))
        >>> graph.print_tabular()
    """
    graph = Graph()
    namespace = graph._graph_namespace

    class MetaProxy(torch.Tensor):
        """
        A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
        """

        _tensor: torch.Tensor
        _node: Node

        __slots__ = ["_tensor", "_node"]

        @staticmethod
        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
                strides=tensor.stride(),
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=fake_device if fake_device is not None else tensor.device,
                requires_grad=tensor.requires_grad,
            )  # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
                if name is None:
                    name = "input"
                r._node = graph.create_node(
                    "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
                )
            # ...the real tensor is held as an element on the tensor.
            if not r._tensor.is_meta:
                r._tensor = r._tensor.to(torch.device("meta"))
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # The signature permits ``kwargs=None``; normalise it so that the
            # ``"device" in kwargs`` test and ``func(*args, **kwargs)`` below
            # cannot fail on ``None``.
            if kwargs is None:
                kwargs = {}

            def unwrap(x):
                # Strip the proxy (or move a plain tensor to meta) and remember
                # the device it reported so outputs can report the same one.
                nonlocal fake_device
                if isinstance(x, MetaProxy):
                    fake_device = x.device
                    x = x._tensor
                elif isinstance(x, torch.Tensor):
                    fake_device = x.device
                    x = x.to(torch.device("meta"))
                return x

            def get_node(x):
                # Tensors without a node yet are registered as "weight"
                # placeholders so they can be referenced from the graph.
                if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
                    x = MetaProxy(x, placeholder=True, name="weight")
                return x if not hasattr(x, "_node") else x._node

            # Record the call before the arguments are unwrapped, so the node
            # refers to graph nodes rather than raw tensors.
            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node("call_function", func, args_node, kwargs_node)

            # Remember the requested device, but execute the op on meta.
            if "device" in kwargs:
                fake_device = kwargs["device"]
                kwargs["device"] = torch.device("meta")

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
                nonlocal fake_device
                if isinstance(x, torch.Tensor):
                    if not x.is_meta:
                        x = x.to(torch.device("meta"))
                return (
                    MetaProxy(x, fake_device=fake_device)
                    if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
                    else x
                )

            def set_node(x):
                # ``wrap`` passes non-tensor leaves through untouched; those
                # cannot carry a ``_node`` attribute, so only tag tensors.
                if isinstance(x, torch.Tensor):
                    x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)

            return out

    def wrap(x):
        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    out = module(*args, **kwargs)

    # Run backward on every differentiable output so that the backward ops
    # also go through ``__torch_dispatch__`` and get recorded.
    for tensor in normalize_tuple(out):
        if is_autogradable(tensor) and tensor.requires_grad:
            grad = (
                torch.empty_like(tensor._tensor, device=torch.device("meta"))
                if isinstance(tensor, MetaProxy)
                else torch.empty_like(tensor, device=torch.device("meta"))
            )
            torch.autograd.backward(
                tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
            )
    return graph