2022-09-05 04:10:09 +00:00
|
|
|
import torch
|
2022-10-26 06:24:41 +00:00
|
|
|
from torch.fx import Graph, Node
|
2022-09-05 04:10:09 +00:00
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
|
|
|
2022-09-23 02:59:47 +00:00
|
|
|
def normalize_tuple(x):
    """Return ``x`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple.

    Only ``tuple`` instances are passed through; any other object (including
    lists and ``None``) becomes the single element of a new tuple.
    """
    return x if isinstance(x, tuple) else (x,)
|
|
|
|
|
|
|
|
|
|
|
|
def is_autogradable(x):
    """Tell whether ``x`` is a ``torch.Tensor`` with a floating-point dtype.

    Non-tensor objects yield ``False``; tensors yield the result of
    ``x.is_floating_point()``.
    """
    if not isinstance(x, torch.Tensor):
        return False
    return x.is_floating_point()
|
|
|
|
|
|
|
|
|
|
|
|
def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
    """Trace forward and backward graph with MetaTensor

    Every tensor found in ``args`` / ``kwargs`` is wrapped in a ``MetaProxy``
    placeholder and ``module`` is called on the wrapped inputs. Each call that
    reaches ``MetaProxy.__torch_dispatch__`` is recorded as a ``call_function``
    node. Afterwards ``torch.autograd.backward`` is invoked on every
    floating-point output that requires grad, so the calls dispatched during
    the backward pass are recorded into the same graph.

    Args:
        module (torch.nn.Module): The target module for tracing.
        fake_device: Device handed to the wrapper tensors. When ``None`` each
            wrapper uses the device of the tensor it wraps. Because this
            parameter precedes ``*args`` it has to be supplied positionally
            before the module inputs.
        *args: Positional inputs forwarded to ``module``.
        **kwargs: Keyword inputs forwarded to ``module``.

    Returns:
        graph (torch.fx.Graph): The computation graph.

    Usage:
        >>> import torchvision.models as tm
        >>> model = tm.alexnet()
        >>> graph = meta_trace(model, torch.device('cpu'), torch.rand(1000, 3, 224, 224))
        >>> graph.print_tabular()
    """
    graph = Graph()
    namespace = graph._graph_namespace

    class MetaProxy(torch.Tensor):
        """
        A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
        """

        # `_tensor` is the wrapped tensor (moved to the meta device in `__new__`).
        _tensor: torch.Tensor
        # `_node` is the graph node associated with this proxy; it is only set for
        # placeholders and for outputs produced inside `__torch_dispatch__`.
        _node: Node

        __slots__ = ['_tensor', '_node']

        @staticmethod
        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
            # The wrapper copies the metadata of `tensor` but may advertise a different device.
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
                strides=tensor.stride(),
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=fake_device if fake_device is not None else tensor.device,
                requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
                if name is None:
                    name = 'input'
                r._node = graph.create_node('placeholder',
                                            'placeholder', (graph._root,),
                                            name=namespace.create_name(name, tensor))
            # ...the real tensor is held as an element on the tensor.
            if not r._tensor.is_meta:
                r._tensor = r._tensor.to(torch.device('meta'))
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # The signature allows `kwargs=None`, but the code below needs a mapping
            # (`'device' in kwargs`, `**kwargs`), so normalise it first.
            if kwargs is None:
                kwargs = {}

            def unwrap(x):
                # `fake_device` is assigned further down in `__torch_dispatch__`, which makes
                # it a local of this dispatch call; `nonlocal` therefore binds to that
                # per-call variable, not to the `fake_device` parameter of `meta_trace`.
                nonlocal fake_device
                if isinstance(x, MetaProxy):
                    fake_device = x.device
                    x = x._tensor
                    # assert not isinstance(x, MetaProxy)
                elif isinstance(x, torch.Tensor):
                    fake_device = x.device
                    x = x.to(torch.device('meta'))
                return x

            def get_node(x):
                # Tensors that carry no node yet are registered as 'weight' placeholders.
                if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
                    x = MetaProxy(x, placeholder=True, name='weight')
                return x if not hasattr(x, '_node') else x._node

            # Record the call with graph nodes substituted for tensor arguments.
            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node('call_function', func, args_node, kwargs_node)

            # Remember the requested device, then redirect the actual call to meta.
            if 'device' in kwargs:
                fake_device = kwargs['device']
                kwargs['device'] = torch.device('meta')

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
                nonlocal fake_device
                if isinstance(x, torch.Tensor):
                    if not x.is_meta:
                        x = x.to(torch.device('meta'))
                return MetaProxy(
                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

            def set_node(x):
                # Only proxies have a `_node` slot; leaves that are not tensors
                # (e.g. Python scalars) are left untouched instead of raising.
                if isinstance(x, MetaProxy):
                    x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)

            return out

    def wrap(x):
        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

    # Every tensor input becomes a placeholder node of the graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    out = module(*args, **kwargs)

    # Run backward on each differentiable output so that the calls dispatched
    # during the backward pass are appended to the same graph.
    for tensor in normalize_tuple(out):
        if is_autogradable(tensor) and tensor.requires_grad:
            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
            torch.autograd.backward(tensor,
                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
                                    retain_graph=True)
    return graph
|