diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py
index e69de29bb..ec6508a30 100644
--- a/colossalai/fx/__init__.py
+++ b/colossalai/fx/__init__.py
@@ -0,0 +1 @@
+from .tracer import ColoTracer
diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index a6c02927c..72f9e646c 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -37,6 +37,12 @@ class ColoProxy(Proxy):
     def _assert_has_meta(self):
         assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
 
+    @property
+    def device(self):
+        # A hack so we can track when the device attribute is accessed. During
+        # meta-tensor propagation, these values are replaced with the constant 'meta'.
+        return MetaDeviceAttribute(self, "device")
+
     @property
     def dtype(self):
         self._assert_has_meta()
@@ -72,3 +78,27 @@ class ColoProxy(Proxy):
 
     def __setitem__(self, indices, values):
         return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
+
+
+class ColoAttribute(ColoProxy):
+
+    def __init__(self, root, attr: str):
+        # this class is copied from torch.fx.Attribute
+        # but inherits from ColoProxy
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._node = None
+
+    @property
+    def node(self):
+        if self._node is None:
+            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
+
+
+class MetaDeviceAttribute(ColoAttribute):
+    pass
diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py
new file mode 100644
index 000000000..ec6508a30
--- /dev/null
+++ b/colossalai/fx/tracer/__init__.py
@@ -0,0 +1 @@
+from .tracer import ColoTracer
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
new file mode 100644
index 000000000..05ecd3843
--- /dev/null
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -0,0 +1,31 @@
+from typing import List, Union, Any
+from ..proxy import ColoProxy, MetaDeviceAttribute
+
+__all__ = ['is_element_in_list', 'extract_meta']
+
+
+def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
+    if isinstance(elements, (tuple, list, set)):
+        for ele in elements:
+            if ele not in list_:
+                return False, ele
+    else:
+        if elements not in list_:
+            return False, elements
+
+    return True, None
+
+
+def extract_meta(*args, **kwargs):
+
+    def _convert(val):
+        if isinstance(val, MetaDeviceAttribute):
+            return 'meta'
+        elif isinstance(val, ColoProxy):
+            assert val.meta_tensor is not None
+            return val.meta_tensor
+        return val
+
+    new_args = [_convert(val) for val in args]
+    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
+    return new_args, new_kwargs
diff --git a/colossalai/fx/tracer/meta_patch/__init__.py b/colossalai/fx/tracer/meta_patch/__init__.py
new file mode 100644
index 000000000..5f0c4745b
--- /dev/null
+++ b/colossalai/fx/tracer/meta_patch/__init__.py
@@ -0,0 +1,3 @@
+from .registry import *
+from .patched_function import *
+from .patched_module import *
diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py
new file mode 100644
index 000000000..f91436fb2
--- /dev/null
+++ b/colossalai/fx/tracer/meta_patch/patched_module.py
@@ -0,0 +1,7 @@
+import torch
+from .registry import meta_patched_module
+
+
+@meta_patched_module.register(torch.nn.Linear)
+def torch_nn_linear(self, input):
+    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/registry.py b/colossalai/fx/tracer/meta_patch/registry.py
new file mode 100644
index 000000000..3eeafe448
--- /dev/null
+++ b/colossalai/fx/tracer/meta_patch/registry.py
@@ -0,0 +1,25 @@
+class PatchRegistry:
+
+    def __init__(self, name):
+        self.name = name
+        self.store = {}
+
+    def register(self, source):
+
+        def wrapper(func):
+            self.store[source] = func
+            return func
+
+        return wrapper
+
+    def get(self, source):
+        assert source in self.store, f'{source} is not registered in the {self.name} registry'
+        target = self.store[source]
+        return target
+
+    def has(self, source):
+        return source in self.store
+
+
+meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
+meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
new file mode 100644
index 000000000..dfeaa8b5c
--- /dev/null
+++ b/colossalai/fx/tracer/tracer.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python
+"""
+tracer.py:
+    Implements a tracer that supports control flow and user-defined meta arguments.
+    The implementation is partly inspired by HuggingFace's fx tracer.
+"""
+
+import inspect
+import math
+import functools
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.fx import Tracer
+from torch.fx.graph import Graph
+from torch.fx.proxy import Proxy, ParameterProxy
+from torch.utils import _pytree
+from ..proxy import ColoProxy
+from typing import Optional, Dict, Any
+from ._tracer_utils import is_element_in_list, extract_meta
+from .meta_patch import meta_patched_function, meta_patched_module
+
+__all__ = ['ColoTracer']
+
+
+class ColoTracer(Tracer):
+    """
+    ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
+    This tracer is initialized in the same way as the original torch.fx.Tracer.
+
+    Usage:
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = nn.Linear(10, 10)
+                self.linear2 = nn.Linear(10, 10)
+
+            def forward(self, x, y):
+                x1 = self.linear1(x)
+                y1 = self.linear2(y)
+
+                if x1.dim() == 2:
+                    return x1 + y1
+                else:
+                    return x1 - y1
+
+        model = Model()
+        tracer = ColoTracer()
+        graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
+    """
+
+    # Feature flag for proxying accesses to buffer values
+    proxy_buffer_attributes: bool = True
+
+    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
+
+    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
+        """
+        Create a proxy for different kinds of operations.
+        """
+        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+        proxy: ColoProxy
+
+        if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
+            proxy.meta_tensor = self.meta_args[target]
+            return proxy
+
+        if target in self.orig_torch_tensor_methods:
+            # NOTE: tensor constructors in PyTorch define the `device` argument as
+            # *kwargs-only*. That is why this works. If you add methods to
+            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
+            # this will break and you will likely see issues where we cannot infer
+            # the size of the output.
+            if "device" in kwargs:
+                kwargs["device"] = "meta"
+
+        try:
+            args_metas, kwargs_metas = extract_meta(*args, **kwargs)
+
+            if kind == "call_function":
+                # fetch patched function
+                if meta_patched_function.has(target):
+                    meta_target = meta_patched_function.get(target)
+                else:
+                    meta_target = target
+
+                meta_out = meta_target(*args_metas, **kwargs_metas)
+                if isinstance(meta_out, torch.Tensor):
+                    meta_out = meta_out.to(device="meta")
+            elif kind == "call_method":
+                method = getattr(args_metas[0].__class__, target)
+
+                # fetch patched method
+                if meta_patched_function.has(method):
+                    meta_target = meta_patched_function.get(method)
+                else:
+                    meta_target = method
+
+                meta_out = meta_target(*args_metas, **kwargs_metas)
+            elif kind == "call_module":
+                if not hasattr(self, "orig_forward"):
+                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
+                self._disable_module_getattr = True
+                try:
+                    mod = self.root.get_submodule(target)
+                    mod_type = type(mod)
+                    if meta_patched_module.has(mod_type):
+                        meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
+                    else:
+                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
+                finally:
+                    self._disable_module_getattr = False
+            elif kind == "get_attr":
+                self._disable_module_getattr = True
+                try:
+                    attr_itr = self.root
+                    atoms = target.split(".")
+                    for atom in atoms:
+                        attr_itr = getattr(attr_itr, atom)
+                    if isinstance(attr_itr, torch.Tensor):
+                        meta_out = attr_itr.to(device="meta")
+                    else:
+                        meta_out = attr_itr
+                finally:
+                    self._disable_module_getattr = False
+            else:
+                return proxy
+
+            if not isinstance(proxy, Proxy):
+                raise ValueError("Don't support composite output yet")
+            proxy.meta_tensor = meta_out
+        except Exception as e:
+            raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
+        return proxy
+
+    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
+        if getattr(self, "_disable_module_getattr", False):
+            return attr_val
+        else:
+            # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+                for n, p in collection_to_search:
+                    if attr_val is p:
+                        if n not in parameter_proxy_cache:
+                            kwargs = {}
+                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+                                kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
+                                                              lambda node: ParameterProxy(self, node, n, attr_val))
+                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)    # type: ignore[arg-type]
+                            parameter_proxy_cache[n] = val_proxy
+                        return parameter_proxy_cache[n]
+                return None
+
+            if isinstance(attr_val, torch.nn.Parameter):
+                maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
+                                                                 parameter_proxy_cache)
+                if maybe_parameter_proxy is not None:
+                    return maybe_parameter_proxy
+
+            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+                maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
+                                                              parameter_proxy_cache)
+                if maybe_buffer_proxy is not None:
+                    return maybe_buffer_proxy
+
+            return attr_val
+
+    def call_module(self, m, forward, args, kwargs):
+        self.orig_forward = forward
+        return super().call_module(m, forward, args, kwargs)
+
+    def proxy(self, node) -> ColoProxy:
+        """
+        Returns a ColoProxy object.
+ """ + return ColoProxy(node, self) + + def trace(self, + root: nn.Module, + concrete_args: Optional[Dict[str, Tensor]] = None, + meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: + """ + Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. + + Args: + root (nn.Module): a `nn.Module` object to trace the computation graph + meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph. + These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors. + concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies. + """ + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + + def _check_arg_name_valid(names): + success, element = is_element_in_list(names, sig_names) + if not success: + raise KeyError( + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + # assign as attributed for late reference + def _check_kwargs(kwargs, should_be_meta: bool): + for k, v in kwargs.items(): + assert v.is_meta == should_be_meta, \ + f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' + + _check_kwargs(concrete_args, should_be_meta=False) + _check_kwargs(meta_args, should_be_meta=True) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + # wrap the torch tensor constructing methods so that they are captured in the graph + self.patched_torch_tensor_methods = { + target: wrap_tensor_constructor_method(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH + } + + # patch these methods to replace their original use + for name, (wrapper, orig) in self.patched_torch_tensor_methods.items(): + setattr(torch, name, wrapper) + + # cache these methods so that we can detect whether a method call + # should be patched during tracing + self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()] + + try: + self.graph = super().trace(root, concrete_args=concrete_args) + finally: + # recover the patched methods + for name, (_, orig) in self.patched_torch_tensor_methods.items(): + setattr(torch, name, orig) + + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. 
+                        to_delete = []
+                        for user in node.users:
+                            if user.target == torch.fx._symbolic_trace._assert_is_none:
+                                to_delete.append(user)
+                        for user in to_delete:
+                            self.graph.erase_node(user)
+
+                    self.graph.erase_node(node)
+
+            # TODO: solve the issue with GraphModule creation.
+            # Without this, the return type annotation "Tuple" causes code execution to fail.
+            if node.op == "output":
+                node.type = None
+
+        return self.graph
+
+
+def wrap_tensor_constructor_method(target):
+
+    def look_for_proxy(*args, **kwargs):
+        # find proxies in positional arguments
+        for arg in args:
+            if isinstance(arg, Proxy):
+                return arg
+
+        # find proxies in keyword arguments
+        for k, v in kwargs.items():
+            if isinstance(v, Proxy):
+                return v
+        return None
+
+    @functools.wraps(target)
+    def wrapper(*args, **kwargs):
+        proxy = look_for_proxy(*args, **kwargs)
+
+        if proxy is not None:
+            # if any argument is a proxy, we need to record this function call on the proxy
+            # e.g. torch.ones(size) where size is an input proxy
+            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+        else:
+            # the function is executed directly when the inputs contain no proxy
+            # e.g. torch.ones(4) where the input is static
+            return target(*args, **kwargs)
+
+    return wrapper, target
diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py
new file mode 100644
index 000000000..ed842cff2
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_control_flow.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+from torch.fx import GraphModule
+from colossalai.fx import ColoTracer as Tracer
+
+
+class ControlFlowModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(10, 10)
+        self.linear2 = nn.Linear(10, 10)
+
+    def forward(self, x, y):
+        x1 = self.linear1(x)
+        y1 = self.linear2(y)
+
+        if x1.dim() == 2:
+            return x1 + y1
+        else:
+            return x1 - y1
+
+
+def test_control_flow():
+    model = ControlFlowModel()
+    tracer = Tracer()
+    graph_branch_true = tracer.trace(model,
+                                     meta_args={
+                                         'x': torch.rand(4, 10, device='meta'),
+                                         'y': torch.rand(4, 10, device='meta')
+                                     })
+    graph_branch_false = tracer.trace(model,
+                                      meta_args={
+                                          'x': torch.rand(10, device='meta'),
+                                          'y': torch.rand(4, 10, device='meta')
+                                      })
+
+    gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
+    gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
+    gm_branch_true.recompile()
+    gm_branch_false.recompile()
+
+    # test the true branch
+    x = torch.rand(4, 10)
+    y = torch.rand(4, 10)
+    assert torch.all(model(x, y) == gm_branch_true(x, y))
+    assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
+
+    # test the false branch
+    x = torch.rand(10)
+    y = torch.rand(4, 10)
+    assert torch.all(model(x, y) == gm_branch_false(x, y))
+    assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
+
+
+if __name__ == '__main__':
+    test_control_flow()
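
Usage note: operations that cannot execute on meta tensors are given patched implementations through the `meta_patched_function` and `meta_patched_module` registries introduced above; `torch_nn_linear` in patched_module.py is the first such patch. Below is a minimal sketch of registering another module patch. It mirrors the `torch.nn.Linear` patch in this diff; the `torch.nn.Embedding` patch itself is a hypothetical illustration, not part of the change set.

import torch
from colossalai.fx.tracer.meta_patch import meta_patched_module

# Hypothetical example patch (not part of this diff): compute the output shape
# of torch.nn.Embedding on meta tensors without running the real kernel.
@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
    # an embedding lookup appends embedding_dim to the input shape
    return torch.empty(*input.shape, self.embedding_dim, device='meta')

During tracing, `create_proxy` looks the module type up via `meta_patched_module.has(mod_type)` and calls the registered function instead of the module's real forward, so only the output shape is computed.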
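Usage note: `concrete_args` and `meta_args` can also be mixed, as shown in the `ColoTracer` docstring. A minimal sketch under that assumption, reusing the test model above (shapes are illustrative):

import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

model = ControlFlowModel()
tracer = ColoTracer()
# 'y' is treated as a concrete value and its placeholder is removed from the
# graph inputs, while 'x' is traced symbolically with a meta tensor, so the
# x1.dim() == 2 branch is resolved at trace time.
graph = tracer.trace(model,
                     concrete_args={'y': torch.rand(4, 10)},
                     meta_args={'x': torch.rand(4, 10, device='meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()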