mirror of https://github.com/hpcaitech/ColossalAI
[fx] supported data-dependent control flow in model tracing (#1185)
* [fx] supported data-dependent control flow in model tracing * polish codepull/1188/head
parent
c463f8adf9
commit
6d86f1bc91
|
@ -0,0 +1 @@
|
||||||
|
from .tracer import ColoTracer
|
|
@ -37,6 +37,12 @@ class ColoProxy(Proxy):
|
||||||
def _assert_has_meta(self):
|
def _assert_has_meta(self):
|
||||||
assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
|
assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
# Hack so we can track when devices are used. During meta-tensor propagation,
|
||||||
|
# replace these values with a constant 'meta'
|
||||||
|
return MetaDeviceAttribute(self, "device")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
self._assert_has_meta()
|
self._assert_has_meta()
|
||||||
|
@ -72,3 +78,27 @@ class ColoProxy(Proxy):
|
||||||
|
|
||||||
def __setitem__(self, indices, values):
|
def __setitem__(self, indices, values):
|
||||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||||
|
|
||||||
|
|
||||||
|
class ColoAttribute(ColoProxy):
|
||||||
|
|
||||||
|
def __init__(self, root, attr: str):
|
||||||
|
# this class is copied from torch.fx.Attribute
|
||||||
|
# but inherits ColoProxy
|
||||||
|
self.root = root
|
||||||
|
self.attr = attr
|
||||||
|
self.tracer = root.tracer
|
||||||
|
self._node = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node(self):
|
||||||
|
if self._node is None:
|
||||||
|
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
|
||||||
|
return self._node
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MetaDeviceAttribute(ColoAttribute):
|
||||||
|
pass
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
from .tracer import ColoTracer
|
|
@ -0,0 +1,31 @@
|
||||||
|
from typing import List, Union, Any
|
||||||
|
from ..proxy import ColoProxy, MetaDeviceAttribute
|
||||||
|
|
||||||
|
__all__ = ['is_element_in_list', 'extract_meta']
|
||||||
|
|
||||||
|
|
||||||
|
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
|
||||||
|
if isinstance(elements, (tuple, list, set)):
|
||||||
|
for ele in elements:
|
||||||
|
if ele not in list_:
|
||||||
|
return False, ele
|
||||||
|
else:
|
||||||
|
if elements not in list_:
|
||||||
|
return False, elements
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_meta(*args, **kwargs):
|
||||||
|
|
||||||
|
def _convert(val):
|
||||||
|
if isinstance(val, MetaDeviceAttribute):
|
||||||
|
return 'meta'
|
||||||
|
elif isinstance(val, ColoProxy):
|
||||||
|
assert val.meta_tensor is not None
|
||||||
|
return val.meta_tensor
|
||||||
|
return val
|
||||||
|
|
||||||
|
new_args = [_convert(val) for val in args]
|
||||||
|
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
||||||
|
return new_args, new_kwargs
|
|
@ -0,0 +1,4 @@
|
||||||
|
from sys import meta_path
|
||||||
|
from .registry import *
|
||||||
|
from .patched_function import *
|
||||||
|
from .patched_module import *
|
|
@ -0,0 +1,7 @@
|
||||||
|
import torch
|
||||||
|
from .registry import meta_patched_module
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_module.register(torch.nn.Linear)
|
||||||
|
def torch_nn_linear(self, input):
|
||||||
|
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
|
|
@ -0,0 +1,25 @@
|
||||||
|
class PatchRegistry:
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
self.store = {}
|
||||||
|
|
||||||
|
def register(self, source):
|
||||||
|
|
||||||
|
def wrapper(func):
|
||||||
|
self.store[source] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def get(self, source):
|
||||||
|
assert source in self.store
|
||||||
|
target = self.store[source]
|
||||||
|
return target
|
||||||
|
|
||||||
|
def has(self, source):
|
||||||
|
return source in self.store
|
||||||
|
|
||||||
|
|
||||||
|
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
|
||||||
|
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
|
|
@ -0,0 +1,305 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
tracer.py:
|
||||||
|
Implemented a tracer which supports control flow and user-defined meta arguments.
|
||||||
|
The implementation is partly inspired HuggingFace's fx tracer
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import math
|
||||||
|
import functools
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.fx import Tracer
|
||||||
|
from torch.fx.graph import Graph
|
||||||
|
from torch.fx.proxy import Proxy, ParameterProxy
|
||||||
|
from torch.utils import _pytree
|
||||||
|
from ..proxy import ColoProxy
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from ._tracer_utils import is_element_in_list, extract_meta
|
||||||
|
from .meta_patch import meta_patched_function, meta_patched_module
|
||||||
|
|
||||||
|
__all__ = ['ColoTracer']
|
||||||
|
|
||||||
|
|
||||||
|
class ColoTracer(Tracer):
|
||||||
|
"""
|
||||||
|
ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
|
||||||
|
This tracer is initialized in the same way as the original torch.fx.Tracer.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(10, 10)
|
||||||
|
self.linear2 = nn.Linear(10, 10)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
x1 = self.linear1(x)
|
||||||
|
y1 = self.linear2(y)
|
||||||
|
|
||||||
|
if x1.dim() == 2:
|
||||||
|
return x1 + y1
|
||||||
|
else:
|
||||||
|
return x1 - y1
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
tracer = ColoTracer()
|
||||||
|
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Feature flag for proxying accesses to buffer values
|
||||||
|
proxy_buffer_attributes: bool = True
|
||||||
|
|
||||||
|
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
||||||
|
|
||||||
|
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
|
||||||
|
"""
|
||||||
|
Create a proxy for different kinds of operations.
|
||||||
|
"""
|
||||||
|
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
|
||||||
|
proxy: ColoProxy
|
||||||
|
|
||||||
|
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
|
||||||
|
proxy.meta_tensor = self.meta_args[target]
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
if target in self.orig_torch_tensor_methods:
|
||||||
|
# NOTE: tensor constructors in PyTorch define the `device` argument as
|
||||||
|
# *kwargs-only*. That is why this works. If you add methods to
|
||||||
|
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
|
||||||
|
# this will break and you will likely see issues where we cannot infer
|
||||||
|
# the size of the output.
|
||||||
|
if "device" in kwargs:
|
||||||
|
kwargs["device"] = "meta"
|
||||||
|
|
||||||
|
try:
|
||||||
|
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
|
||||||
|
|
||||||
|
if kind == "call_function":
|
||||||
|
# fetch patched function
|
||||||
|
if meta_patched_function.has(target):
|
||||||
|
meta_target = meta_patched_function.get(target)
|
||||||
|
else:
|
||||||
|
meta_target = target
|
||||||
|
|
||||||
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||||
|
if isinstance(meta_out, torch.Tensor):
|
||||||
|
meta_out = meta_out.to(device="meta")
|
||||||
|
elif kind == "call_method":
|
||||||
|
method = getattr(args_metas[0].__class__, target)
|
||||||
|
|
||||||
|
# fetch patched method
|
||||||
|
if meta_patched_function.has(method):
|
||||||
|
meta_target = meta_patched_function.get(method)
|
||||||
|
else:
|
||||||
|
meta_target = method
|
||||||
|
|
||||||
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||||
|
elif kind == "call_module":
|
||||||
|
if not hasattr(self, "orig_forward"):
|
||||||
|
raise AttributeError(f"{self} does not have an attribute called orig_forward")
|
||||||
|
self._disable_module_getattr = True
|
||||||
|
try:
|
||||||
|
mod = self.root.get_submodule(target)
|
||||||
|
mod_type = type(mod)
|
||||||
|
if meta_patched_module.has(mod_type):
|
||||||
|
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
|
||||||
|
else:
|
||||||
|
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
||||||
|
finally:
|
||||||
|
self._disable_module_getattr = False
|
||||||
|
elif kind == "get_attr":
|
||||||
|
self._disable_module_getattr = True
|
||||||
|
try:
|
||||||
|
attr_itr = self.root
|
||||||
|
atoms = target.split(".")
|
||||||
|
for atom in atoms:
|
||||||
|
attr_itr = getattr(attr_itr, atom)
|
||||||
|
if isinstance(attr_itr, torch.Tensor):
|
||||||
|
meta_out = attr_itr.to(device="meta")
|
||||||
|
else:
|
||||||
|
meta_out = attr_itr
|
||||||
|
finally:
|
||||||
|
self._disable_module_getattr = False
|
||||||
|
else:
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
if not isinstance(proxy, Proxy):
|
||||||
|
raise ValueError("Don't support composite output yet")
|
||||||
|
proxy.meta_tensor = meta_out
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
|
||||||
|
if getattr(self, "_disable_module_getattr", False):
|
||||||
|
return attr_val
|
||||||
|
else:
|
||||||
|
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||||
|
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
|
||||||
|
for n, p in collection_to_search:
|
||||||
|
if attr_val is p:
|
||||||
|
if n not in parameter_proxy_cache:
|
||||||
|
kwargs = {}
|
||||||
|
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
|
||||||
|
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
|
||||||
|
lambda node: ParameterProxy(self, node, n, attr_val))
|
||||||
|
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
|
||||||
|
parameter_proxy_cache[n] = val_proxy
|
||||||
|
return parameter_proxy_cache[n]
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(attr_val, torch.nn.Parameter):
|
||||||
|
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
|
||||||
|
parameter_proxy_cache)
|
||||||
|
if maybe_parameter_proxy is not None:
|
||||||
|
return maybe_parameter_proxy
|
||||||
|
|
||||||
|
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
|
||||||
|
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
|
||||||
|
parameter_proxy_cache)
|
||||||
|
if maybe_buffer_proxy is not None:
|
||||||
|
return maybe_buffer_proxy
|
||||||
|
|
||||||
|
return attr_val
|
||||||
|
|
||||||
|
def call_module(self, m, forward, args, kwargs):
|
||||||
|
self.orig_forward = forward
|
||||||
|
return super().call_module(m, forward, args, kwargs)
|
||||||
|
|
||||||
|
def proxy(self, node) -> ColoProxy:
|
||||||
|
"""
|
||||||
|
Returns a ColoProxy object.
|
||||||
|
"""
|
||||||
|
return ColoProxy(node, self)
|
||||||
|
|
||||||
|
def trace(self,
|
||||||
|
root: nn.Module,
|
||||||
|
concrete_args: Optional[Dict[str, Tensor]] = None,
|
||||||
|
meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
|
||||||
|
"""
|
||||||
|
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root (nn.Module): a `nn.Module` object to trace the computation graph
|
||||||
|
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
|
||||||
|
These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
|
||||||
|
concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
|
||||||
|
"""
|
||||||
|
if meta_args is None:
|
||||||
|
meta_args = {}
|
||||||
|
|
||||||
|
if concrete_args is None:
|
||||||
|
concrete_args = {}
|
||||||
|
|
||||||
|
# check concrete and meta args have valid names
|
||||||
|
sig = inspect.signature(root.forward)
|
||||||
|
sig_names = set(sig.parameters.keys())
|
||||||
|
meta_arg_names = set(meta_args.keys())
|
||||||
|
concrete_arg_names = set(concrete_args.keys())
|
||||||
|
non_concrete_arg_names = sig_names - concrete_arg_names
|
||||||
|
|
||||||
|
def _check_arg_name_valid(names):
|
||||||
|
success, element = is_element_in_list(names, sig_names)
|
||||||
|
if not success:
|
||||||
|
raise KeyError(
|
||||||
|
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
|
||||||
|
|
||||||
|
_check_arg_name_valid(meta_arg_names)
|
||||||
|
_check_arg_name_valid(concrete_arg_names)
|
||||||
|
|
||||||
|
# assign as attributed for late reference
|
||||||
|
def _check_kwargs(kwargs, should_be_meta: bool):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
assert v.is_meta == should_be_meta, \
|
||||||
|
f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
|
||||||
|
|
||||||
|
_check_kwargs(concrete_args, should_be_meta=False)
|
||||||
|
_check_kwargs(meta_args, should_be_meta=True)
|
||||||
|
|
||||||
|
self.concrete_args = concrete_args
|
||||||
|
self.meta_args = meta_args
|
||||||
|
|
||||||
|
# wrap the torch tensor constructing methods so that they are captured in the graph
|
||||||
|
self.patched_torch_tensor_methods = {
|
||||||
|
target: wrap_tensor_constructor_method(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
||||||
|
}
|
||||||
|
|
||||||
|
# patch these methods to replace their original use
|
||||||
|
for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
|
||||||
|
setattr(torch, name, wrapper)
|
||||||
|
|
||||||
|
# cache these methods so that we can detect whether a method call
|
||||||
|
# should be patched during tracing
|
||||||
|
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||||
|
finally:
|
||||||
|
# recover the patched methods
|
||||||
|
for name, (_, orig) in self.patched_torch_tensor_methods.items():
|
||||||
|
setattr(torch, name, orig)
|
||||||
|
|
||||||
|
# This is necessary because concrete args are added as input to the traced module since
|
||||||
|
# https://github.com/pytorch/pytorch/pull/55888.
|
||||||
|
for node in self.graph.nodes:
|
||||||
|
if node.op == "placeholder":
|
||||||
|
# Removing default values for inputs as the forward pass will fail with them.
|
||||||
|
if node.target in non_concrete_arg_names:
|
||||||
|
node.args = ()
|
||||||
|
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
|
||||||
|
# It cannot infer on the attributes and methods the input should have, and fails.
|
||||||
|
node.type = torch.Tensor
|
||||||
|
# It is a concrete arg so it is not used and should be removed.
|
||||||
|
else:
|
||||||
|
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
|
||||||
|
# Newer versions of torch.fx emit an assert statement
|
||||||
|
# for concrete arguments; delete those before we delete
|
||||||
|
# the concrete arg.
|
||||||
|
to_delete = []
|
||||||
|
for user in node.users:
|
||||||
|
if user.target == torch.fx._symbolic_trace._assert_is_none:
|
||||||
|
to_delete.append(user)
|
||||||
|
for user in to_delete:
|
||||||
|
self.graph.erase_node(user)
|
||||||
|
|
||||||
|
self.graph.erase_node(node)
|
||||||
|
|
||||||
|
# TODO: solves GraphModule creation.
|
||||||
|
# Without this, return type annotation "Tuple" is causing code execution failure.
|
||||||
|
if node.op == "output":
|
||||||
|
node.type = None
|
||||||
|
|
||||||
|
return self.graph
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_tensor_constructor_method(target):
|
||||||
|
|
||||||
|
def look_for_proxy(*args, **kwargs):
|
||||||
|
# find in pos vars
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, Proxy):
|
||||||
|
return arg
|
||||||
|
|
||||||
|
# find in keyword vars
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, Proxy):
|
||||||
|
return v
|
||||||
|
return None
|
||||||
|
|
||||||
|
@functools.wraps(target)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
proxy = look_for_proxy(*args, **kwargs)
|
||||||
|
|
||||||
|
if proxy is not None:
|
||||||
|
# if the arg is a proxy, then need to record this function called on this proxy
|
||||||
|
# e.g. torch.ones(size) where size is an input proxy
|
||||||
|
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
||||||
|
else:
|
||||||
|
# this is called directly when the inputs do not contain proxy
|
||||||
|
# e.g. torch.ones(4) where the input is static
|
||||||
|
return target(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper, target
|
|
@ -0,0 +1,57 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from colossalai.fx import ColoTracer as Tracer
|
||||||
|
|
||||||
|
|
||||||
|
class ControlFlowModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(10, 10)
|
||||||
|
self.linear2 = nn.Linear(10, 10)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
x1 = self.linear1(x)
|
||||||
|
y1 = self.linear2(y)
|
||||||
|
|
||||||
|
if x1.dim() == 2:
|
||||||
|
return x1 + y1
|
||||||
|
else:
|
||||||
|
return x1 - y1
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_flow():
|
||||||
|
model = ControlFlowModel()
|
||||||
|
tracer = Tracer()
|
||||||
|
graph_branch_true = tracer.trace(model,
|
||||||
|
meta_args={
|
||||||
|
'x': torch.rand(4, 10, device='meta'),
|
||||||
|
'y': torch.rand(4, 10, device='meta')
|
||||||
|
})
|
||||||
|
graph_branch_false = tracer.trace(model,
|
||||||
|
meta_args={
|
||||||
|
'x': torch.rand(10, device='meta'),
|
||||||
|
'y': torch.rand(4, 10, device='meta')
|
||||||
|
})
|
||||||
|
|
||||||
|
gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
|
||||||
|
gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
|
||||||
|
gm_branch_true.recompile()
|
||||||
|
gm_branch_false.recompile()
|
||||||
|
|
||||||
|
# test the true branch
|
||||||
|
x = torch.rand(4, 10)
|
||||||
|
y = torch.rand(4, 10)
|
||||||
|
assert torch.all(model(x, y) == gm_branch_true(x, y))
|
||||||
|
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
|
||||||
|
|
||||||
|
# test the true branch
|
||||||
|
x = torch.rand(10)
|
||||||
|
y = torch.rand(4, 10)
|
||||||
|
assert torch.all(model(x, y) == gm_branch_false(x, y))
|
||||||
|
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_control_flow()
|
Loading…
Reference in New Issue