[fx] supported data-dependent control flow in model tracing (#1185)

* [fx] supported data-dependent control flow in model tracing

* polish code
pull/1188/head
Frank Lee 2022-06-29 15:05:25 +08:00 committed by GitHub
parent c463f8adf9
commit 6d86f1bc91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 461 additions and 0 deletions

View File

@ -0,0 +1 @@
from .tracer import ColoTracer

View File

@ -37,6 +37,12 @@ class ColoProxy(Proxy):
def _assert_has_meta(self):
assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, "device")
@property
def dtype(self):
self._assert_has_meta()
@ -72,3 +78,27 @@ class ColoProxy(Proxy):
def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str):
# this class is copied from torch.fx.Attribute
# but inherits ColoProxy
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None
@property
def node(self):
if self._node is None:
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
class MetaDeviceAttribute(ColoAttribute):
pass

View File

@ -0,0 +1 @@
from .tracer import ColoTracer

View File

@ -0,0 +1,31 @@
from typing import List, Union, Any
from ..proxy import ColoProxy, MetaDeviceAttribute
__all__ = ['is_element_in_list', 'extract_meta']
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
if isinstance(elements, (tuple, list, set)):
for ele in elements:
if ele not in list_:
return False, ele
else:
if elements not in list_:
return False, elements
return True, None
def extract_meta(*args, **kwargs):
def _convert(val):
if isinstance(val, MetaDeviceAttribute):
return 'meta'
elif isinstance(val, ColoProxy):
assert val.meta_tensor is not None
return val.meta_tensor
return val
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs

View File

@ -0,0 +1,4 @@
from sys import meta_path
from .registry import *
from .patched_function import *
from .patched_module import *

View File

@ -0,0 +1,7 @@
import torch
from .registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")

View File

@ -0,0 +1,25 @@
class PatchRegistry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
self.store[source] = func
return func
return wrapper
def get(self, source):
assert source in self.store
target = self.store[source]
return target
def has(self, source):
return source in self.store
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')

View File

@ -0,0 +1,305 @@
#!/usr/bin/env python
"""
tracer.py:
Implemented a tracer which supports control flow and user-defined meta arguments.
The implementation is partly inspired HuggingFace's fx tracer
"""
import inspect
import math
import functools
import torch
import torch.nn as nn
from torch import Tensor
from torch.fx import Tracer
from torch.fx.graph import Graph
from torch.fx.proxy import Proxy, ParameterProxy
from torch.utils import _pytree
from ..proxy import ColoProxy
from typing import Optional, Dict, Any
from ._tracer_utils import is_element_in_list, extract_meta
from .meta_patch import meta_patched_function, meta_patched_module
__all__ = ['ColoTracer']
class ColoTracer(Tracer):
"""
ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
This tracer is initialized in the same way as the original torch.fx.Tracer.
Usage:
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 10)
self.linear2 = nn.Linear(10, 10)
def forward(self, x, y):
x1 = self.linear1(x)
y1 = self.linear2(y)
if x1.dim() == 2:
return x1 + y1
else:
return x1 - y1
model = Model()
tracer = ColoTracer()
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
"""
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
"""
Create a proxy for different kinds of operations.
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
proxy: ColoProxy
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
proxy.meta_tensor = self.meta_args[target]
return proxy
if target in self.orig_torch_tensor_methods:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
if kind == "call_function":
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
# fetch patched method
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
else:
meta_target = method
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if meta_patched_module.has(mod_type):
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
else:
return proxy
if not isinstance(proxy, Proxy):
raise ValueError("Don't support composite output yet")
proxy.meta_tensor = meta_out
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return proxy
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
lambda node: ParameterProxy(self, node, n, attr_val))
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
def proxy(self, node) -> ColoProxy:
"""
Returns a ColoProxy object.
"""
return ColoProxy(node, self)
def trace(self,
root: nn.Module,
concrete_args: Optional[Dict[str, Tensor]] = None,
meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
"""
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
Args:
root (nn.Module): a `nn.Module` object to trace the computation graph
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
"""
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
# assign as attributed for late reference
def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items():
assert v.is_meta == should_be_meta, \
f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
_check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True)
self.concrete_args = concrete_args
self.meta_args = meta_args
# wrap the torch tensor constructing methods so that they are captured in the graph
self.patched_torch_tensor_methods = {
target: wrap_tensor_constructor_method(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
# patch these methods to replace their original use
for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
setattr(torch, name, wrapper)
# cache these methods so that we can detect whether a method call
# should be patched during tracing
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
try:
self.graph = super().trace(root, concrete_args=concrete_args)
finally:
# recover the patched methods
for name, (_, orig) in self.patched_torch_tensor_methods.items():
setattr(torch, name, orig)
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
# TODO: solves GraphModule creation.
# Without this, return type annotation "Tuple" is causing code execution failure.
if node.op == "output":
node.type = None
return self.graph
def wrap_tensor_constructor_method(target):
def look_for_proxy(*args, **kwargs):
# find in pos vars
for arg in args:
if isinstance(arg, Proxy):
return arg
# find in keyword vars
for k, v in kwargs.items():
if isinstance(v, Proxy):
return v
return None
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = look_for_proxy(*args, **kwargs)
if proxy is not None:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
# this is called directly when the inputs do not contain proxy
# e.g. torch.ones(4) where the input is static
return target(*args, **kwargs)
return wrapper, target

View File

@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.fx import ColoTracer as Tracer
class ControlFlowModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 10)
self.linear2 = nn.Linear(10, 10)
def forward(self, x, y):
x1 = self.linear1(x)
y1 = self.linear2(y)
if x1.dim() == 2:
return x1 + y1
else:
return x1 - y1
def test_control_flow():
model = ControlFlowModel()
tracer = Tracer()
graph_branch_true = tracer.trace(model,
meta_args={
'x': torch.rand(4, 10, device='meta'),
'y': torch.rand(4, 10, device='meta')
})
graph_branch_false = tracer.trace(model,
meta_args={
'x': torch.rand(10, device='meta'),
'y': torch.rand(4, 10, device='meta')
})
gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
gm_branch_true.recompile()
gm_branch_false.recompile()
# test the true branch
x = torch.rand(4, 10)
y = torch.rand(4, 10)
assert torch.all(model(x, y) == gm_branch_true(x, y))
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
# test the true branch
x = torch.rand(10)
y = torch.rand(4, 10)
assert torch.all(model(x, y) == gm_branch_false(x, y))
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
if __name__ == '__main__':
test_control_flow()