import torch
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
    """Trace forward and backward graph with MetaTensor

    Args:
        module (torch.nn.Module): The target module for tracing.
        fake_device (torch.device, optional): The device the traced tensors pretend to
            live on. Defaults to the device of the input tensors.
        *args, **kwargs: Inputs forwarded to ``module``.

    Returns:
        graph (torch.fx.Graph): The computation graph.

    Usage:
        >>> import torchvision.models as tm
        >>> model = tm.alexnet()
        >>> graph = meta_trace(model, None, torch.rand(1000, 3, 224, 224))
        >>> graph.print_tabular()
    """
    graph = Graph()
    namespace = graph._graph_namespace

    class MetaProxy(torch.Tensor):
        """
        A wrapping tensor that hacks `torch.autograd` without patching more
        `torch.ops.aten` ops.
        """

        _tensor: torch.Tensor
        _node: Node

        __slots__ = ['_tensor', '_node']

        @staticmethod
        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
                strides=tensor.stride(),
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=fake_device if fake_device is not None else tensor.device,
                requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
                if name is None:
                    name = 'input'
                r._node = graph.create_node('placeholder',
                                            'placeholder', (graph._root,),
                                            name=namespace.create_name(name, tensor))
            # ...the real tensor is held as an element on the tensor.
            if not r._tensor.is_meta:
                r._tensor = r._tensor.to(torch.device('meta'))
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs if kwargs is not None else {}

            def unwrap(x):
                # Strip the proxy (or move a plain tensor to meta) so the aten op
                # really runs on the meta backend; remember the pretended device.
                nonlocal fake_device
                if isinstance(x, MetaProxy):
                    fake_device = x.device
                    x = x._tensor
                elif isinstance(x, torch.Tensor):
                    fake_device = x.device
                    x = x.to(torch.device('meta'))
                return x

            def get_node(x):
                # Tensors that were never wrapped (e.g. parameters) become placeholders.
                if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
                    x = MetaProxy(x, placeholder=True, name='weight')
                return x if not hasattr(x, '_node') else x._node

            # Record the aten call as a `call_function` node in the graph.
            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node('call_function', func, args_node, kwargs_node)

            if 'device' in kwargs:
                fake_device = kwargs['device']
                kwargs['device'] = torch.device('meta')

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
                if isinstance(x, torch.Tensor):
                    nonlocal fake_device
                    if not x.is_meta:
                        x = x.to(torch.device('meta'))
                return MetaProxy(
                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

            def set_node(x):
                # Only tensors carry graph nodes; skip non-tensor outputs.
                if isinstance(x, torch.Tensor):
                    x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)
            return out

    def wrap(x):
        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    # Forward pass: every aten op hit along the way is recorded by MetaProxy.
    out = module(*args, **kwargs)

    # Backward pass: feed meta gradients so the backward aten ops are recorded too.
    for tensor in normalize_tuple(out):
        if is_autogradable(tensor) and tensor.requires_grad:
            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
            torch.autograd.backward(tensor,
                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
                                    retain_graph=True)
    return graph
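

# Illustrative usage: a minimal, self-contained sketch, not part of the tracing
# utility itself. It mirrors the docstring example but uses a tiny
# `torch.nn.Sequential` model so no extra packages are required; the model,
# shapes, and fake_device choice below are arbitrary and for demonstration only.
if __name__ == '__main__':
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 4),
        torch.nn.ReLU(),
        torch.nn.Linear(4, 2),
    )
    # Trace the forward and backward aten-level graph on the meta backend.
    demo_graph = meta_trace(model, None, torch.rand(2, 8))
    demo_graph.print_tabular()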