diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 697b73a74..b5fff7469 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,4 +1,9 @@ +try: + from ._meta_registrations import * +except: + import torch + print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, get_default_parser) -__version__ = '0.0.1' +__version__ = '0.1.9' diff --git a/colossalai/fx/profiler/_meta_registrations.py b/colossalai/_meta_registrations.py similarity index 100% rename from colossalai/fx/profiler/_meta_registrations.py rename to colossalai/_meta_registrations.py diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index 6513f6d03..6d0475f70 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1,2 +1,2 @@ -from .tracer import ColoTracer +from .tracer import ColoTracer, meta_trace from .graph_module import ColoGraphModule diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 4b90bcb30..9d657ad22 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,8 +1,3 @@ -try: - from ._meta_registrations import * -except: - import torch - print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') from .meta_tensor import MetaTensor from .registry import meta_profiler_function, meta_profiler_module from .profiler_function import * diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py index ec6508a30..327e1510e 100644 --- a/colossalai/fx/tracer/__init__.py +++ b/colossalai/fx/tracer/__init__.py @@ -1 +1,2 @@ from .tracer import ColoTracer +from ._meta_trace import meta_trace diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py new file mode 100644 index 000000000..48b3e2deb --- /dev/null +++ b/colossalai/fx/tracer/_meta_trace.py @@ -0,0 +1,99 @@ +import torch +from torch.fx import Node, Graph +from torch.fx.graph import _Namespace +from torch.utils._pytree import tree_map + + +def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph: + """Trace forward and backward graph with MetaTensor + + Args: + module (torch.nn.Module): The target module for tracing. + + Returns: + graph (torch.fx.Graph): The computation graph. + + Usage: + >>> import torchvision.models as tm + >>> model = tm.alexnet() + >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224)) + >>> graph.print_tabular() + """ + graph = Graph() + namespace = _Namespace() + + class MetaProxy(torch.Tensor): + """ + A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. + """ + + _tensor: torch.Tensor + _node: Node + + __slots__ = ['_tensor', '_node'] + + @staticmethod + def __new__(cls, tensor, placeholder=False, name=None): + r = torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + device='cpu', + requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + r._tensor = tensor + if placeholder: + if name is None: + name = 'input' + r._node = graph.create_node('placeholder', + 'placeholder', (graph._root,), + name=namespace.create_name(name, tensor)) + # ...the real tensor is held as an element on the tensor. + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + def unwrap(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): + x = MetaProxy(x) + return x._tensor.to('meta') if isinstance(x, MetaProxy) else x + + def get_node(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): + x = MetaProxy(x, placeholder=True, name='weight') + return x if not hasattr(x, '_node') else x._node + + args_node = tree_map(get_node, args) + kwargs_node = tree_map(get_node, kwargs) + node = graph.create_node('call_function', func, args_node, kwargs_node) + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + + def set_node(x): + x._node = node + + out = tree_map(wrap, out) + tree_map(set_node, out) + + return out + + def wrap(x): + return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + module(*args, **kwargs).sum().backward() + return graph