mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
355 lines
15 KiB
355 lines
15 KiB
#!/usr/bin/env python
|
|
"""
|
|
tracer.py:
|
|
Implemented a tracer which supports control flow and user-defined meta arguments.
|
|
The implementation is partly inspired HuggingFace's fx tracer
|
|
"""
|
|
import enum
|
|
import inspect
|
|
import functools
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.fx import Tracer
|
|
from torch.fx.graph import Graph
|
|
from torch.fx.proxy import Proxy, ParameterProxy
|
|
from ..proxy import ColoProxy
|
|
from typing import Optional, Dict, Any
|
|
from ._tracer_utils import is_element_in_list, extract_meta
|
|
from .meta_patch import meta_patched_function, meta_patched_module
|
|
|
|
__all__ = ['ColoTracer']
|
|
|
|
|
|
class TracerType(enum.Enum):
|
|
DEFAULT = 1
|
|
META = 2
|
|
|
|
|
|
class ColoTracer(Tracer):
|
|
"""
|
|
ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
|
|
This tracer is initialized in the same way as the original torch.fx.Tracer.
|
|
|
|
Usage:
|
|
class Model(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = nn.Linear(10, 10)
|
|
self.linear2 = nn.Linear(10, 10)
|
|
|
|
def forward(self, x, y):
|
|
x1 = self.linear1(x)
|
|
y1 = self.linear2(y)
|
|
|
|
if x1.dim() == 2:
|
|
return x1 + y1
|
|
else:
|
|
return x1 - y1
|
|
|
|
model = Model()
|
|
tracer = ColoTracer()
|
|
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.tracer_type = TracerType.META
|
|
self.proxy_cls = ColoProxy
|
|
|
|
# Feature flag for proxying accesses to buffer values
|
|
proxy_buffer_attributes: bool = True
|
|
|
|
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
|
|
|
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
|
|
"""
|
|
Create a proxy for different kinds of operations.
|
|
"""
|
|
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
|
|
|
|
if self.tracer_type == TracerType.DEFAULT:
|
|
# since meta_args is not given
|
|
# we just fall back to the original torch.fx.Tracer
|
|
return proxy
|
|
|
|
proxy: ColoProxy
|
|
|
|
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
|
|
proxy.meta_data = self.meta_args[target]
|
|
return proxy
|
|
|
|
if target in self.orig_torch_tensor_methods:
|
|
# NOTE: tensor constructors in PyTorch define the `device` argument as
|
|
# *kwargs-only*. That is why this works. If you add methods to
|
|
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
|
|
# this will break and you will likely see issues where we cannot infer
|
|
# the size of the output.
|
|
if "device" in kwargs:
|
|
kwargs["device"] = "meta"
|
|
|
|
try:
|
|
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
|
|
|
|
if kind == "call_function":
|
|
# fetch patched function
|
|
if meta_patched_function.has(target):
|
|
meta_target = meta_patched_function.get(target)
|
|
else:
|
|
meta_target = target
|
|
|
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
|
if isinstance(meta_out, torch.Tensor):
|
|
meta_out = meta_out.to(device="meta")
|
|
elif kind == "call_method":
|
|
method = getattr(args_metas[0].__class__, target)
|
|
|
|
# fetch patched method
|
|
if meta_patched_function.has(method):
|
|
meta_target = meta_patched_function.get(method)
|
|
else:
|
|
meta_target = method
|
|
|
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
|
elif kind == "call_module":
|
|
if not hasattr(self, "orig_forward"):
|
|
raise AttributeError(f"{self} does not have an attribute called orig_forward")
|
|
self._disable_module_getattr = True
|
|
try:
|
|
mod = self.root.get_submodule(target)
|
|
mod_type = type(mod)
|
|
if meta_patched_module.has(mod_type):
|
|
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
|
|
else:
|
|
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
|
finally:
|
|
self._disable_module_getattr = False
|
|
elif kind == "get_attr":
|
|
self._disable_module_getattr = True
|
|
try:
|
|
attr_itr = self.root
|
|
atoms = target.split(".")
|
|
for atom in atoms:
|
|
attr_itr = getattr(attr_itr, atom)
|
|
if isinstance(attr_itr, torch.Tensor):
|
|
meta_out = attr_itr.to(device="meta")
|
|
else:
|
|
meta_out = attr_itr
|
|
finally:
|
|
self._disable_module_getattr = False
|
|
else:
|
|
return proxy
|
|
|
|
if not isinstance(proxy, Proxy):
|
|
raise ValueError("Don't support composite output yet")
|
|
proxy.meta_data = meta_out
|
|
except Exception as e:
|
|
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
|
|
return proxy
|
|
|
|
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
|
|
if getattr(self, "_disable_module_getattr", False):
|
|
return attr_val
|
|
else:
|
|
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
|
|
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
|
|
for n, p in collection_to_search:
|
|
if attr_val is p:
|
|
if n not in parameter_proxy_cache:
|
|
kwargs = {}
|
|
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
|
|
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
|
|
lambda node: ParameterProxy(self, node, n, attr_val))
|
|
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
|
|
parameter_proxy_cache[n] = val_proxy
|
|
return parameter_proxy_cache[n]
|
|
return None
|
|
|
|
if isinstance(attr_val, torch.nn.Parameter):
|
|
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
|
|
parameter_proxy_cache)
|
|
if maybe_parameter_proxy is not None:
|
|
return maybe_parameter_proxy
|
|
|
|
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
|
|
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
|
|
parameter_proxy_cache)
|
|
if maybe_buffer_proxy is not None:
|
|
return maybe_buffer_proxy
|
|
|
|
return attr_val
|
|
|
|
def call_module(self, m, forward, args, kwargs):
|
|
self.orig_forward = forward
|
|
return super().call_module(m, forward, args, kwargs)
|
|
|
|
def proxy(self, node) -> Proxy:
|
|
"""
|
|
Returns a ColoProxy object.
|
|
"""
|
|
return self.proxy_cls(node, self)
|
|
|
|
def _configure_tracer_type(self, tracer_type: TracerType):
|
|
if tracer_type == TracerType.DEFAULT:
|
|
self.proxy_cls = Proxy
|
|
self.tracer_type = TracerType.DEFAULT
|
|
elif tracer_type == TracerType.META:
|
|
self.proxy_cls = ColoProxy
|
|
self.tracer_type = TracerType.META
|
|
else:
|
|
raise ValueError(f"Unrecognised tracer type {tracer_type}")
|
|
|
|
def trace(self,
|
|
root: nn.Module,
|
|
concrete_args: Optional[Dict[str, Tensor]] = None,
|
|
meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
|
|
"""
|
|
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
|
|
|
|
Args:
|
|
root (nn.Module): a `nn.Module` object to trace the computation graph
|
|
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
|
|
These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
|
|
concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
|
|
"""
|
|
if meta_args is None:
|
|
meta_args = {}
|
|
|
|
if concrete_args is None:
|
|
concrete_args = {}
|
|
|
|
if len(meta_args) == 0:
|
|
self._configure_tracer_type(TracerType.DEFAULT)
|
|
else:
|
|
self._configure_tracer_type(TracerType.META)
|
|
|
|
# check concrete and meta args have valid names
|
|
sig = inspect.signature(root.forward)
|
|
sig_names = set(sig.parameters.keys())
|
|
meta_arg_names = set(meta_args.keys())
|
|
|
|
# update concrete args with default values
|
|
non_meta_arg_names = sig_names - meta_arg_names
|
|
for k, v in sig.parameters.items():
|
|
if k in non_meta_arg_names and \
|
|
k not in concrete_args and \
|
|
v.default is not inspect.Parameter.empty:
|
|
concrete_args[k] = v.default
|
|
|
|
# get non concrete arg names
|
|
concrete_arg_names = set(concrete_args.keys())
|
|
non_concrete_arg_names = sig_names - concrete_arg_names
|
|
|
|
def _check_arg_name_valid(names):
|
|
success, element = is_element_in_list(names, sig_names)
|
|
if not success:
|
|
raise KeyError(
|
|
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
|
|
|
|
_check_arg_name_valid(meta_arg_names)
|
|
_check_arg_name_valid(concrete_arg_names)
|
|
|
|
# assign as attributed for late reference
|
|
def _check_kwargs(kwargs, should_be_meta: bool):
|
|
for k, v in kwargs.items():
|
|
if not should_be_meta:
|
|
assert not torch.is_tensor(v) or not v.is_meta, \
|
|
f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
|
|
else:
|
|
assert v.is_meta == should_be_meta, \
|
|
f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
|
|
|
|
_check_kwargs(concrete_args, should_be_meta=False)
|
|
_check_kwargs(meta_args, should_be_meta=True)
|
|
|
|
self.concrete_args = concrete_args
|
|
self.meta_args = meta_args
|
|
|
|
self.patched_torch_tensor_methods = {}
|
|
if self.tracer_type == TracerType.META:
|
|
# wrap the torch tensor constructing methods so that they are captured in the graph
|
|
self.patched_torch_tensor_methods = {
|
|
target: wrap_tensor_constructor_method(getattr(torch, target))
|
|
for target in self._TORCH_METHODS_TO_PATCH
|
|
}
|
|
|
|
# patch these methods to replace their original use
|
|
for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
|
|
setattr(torch, name, wrapper)
|
|
|
|
# cache these methods so that we can detect whether a method call
|
|
# should be patched during tracing
|
|
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
|
|
|
|
try:
|
|
self.graph = super().trace(root, concrete_args=concrete_args)
|
|
finally:
|
|
# recover the patched methods
|
|
for name, (_, orig) in self.patched_torch_tensor_methods.items():
|
|
setattr(torch, name, orig)
|
|
|
|
if self.tracer_type == TracerType.DEFAULT:
|
|
return self.graph
|
|
|
|
# This is necessary because concrete args are added as input to the traced module since
|
|
# https://github.com/pytorch/pytorch/pull/55888.
|
|
for node in self.graph.nodes:
|
|
if node.op == "placeholder":
|
|
# Removing default values for inputs as the forward pass will fail with them.
|
|
if node.target in non_concrete_arg_names:
|
|
node.args = ()
|
|
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
|
|
# It cannot infer on the attributes and methods the input should have, and fails.
|
|
node.type = torch.Tensor
|
|
# It is a concrete arg so it is not used and should be removed.
|
|
else:
|
|
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
|
|
# Newer versions of torch.fx emit an assert statement
|
|
# for concrete arguments; delete those before we delete
|
|
# the concrete arg.
|
|
to_delete = []
|
|
for user in node.users:
|
|
if user.target == torch.fx._symbolic_trace._assert_is_none:
|
|
to_delete.append(user)
|
|
for user in to_delete:
|
|
self.graph.erase_node(user)
|
|
|
|
self.graph.erase_node(node)
|
|
|
|
# TODO: solves GraphModule creation.
|
|
# Without this, return type annotation "Tuple" is causing code execution failure.
|
|
if node.op == "output":
|
|
node.type = None
|
|
|
|
return self.graph
|
|
|
|
|
|
def wrap_tensor_constructor_method(target):
|
|
|
|
def look_for_proxy(*args, **kwargs):
|
|
# find in pos vars
|
|
for arg in args:
|
|
if isinstance(arg, Proxy):
|
|
return arg
|
|
|
|
# find in keyword vars
|
|
for k, v in kwargs.items():
|
|
if isinstance(v, Proxy):
|
|
return v
|
|
return None
|
|
|
|
@functools.wraps(target)
|
|
def wrapper(*args, **kwargs):
|
|
proxy = look_for_proxy(*args, **kwargs)
|
|
|
|
if proxy is not None:
|
|
# if the arg is a proxy, then need to record this function called on this proxy
|
|
# e.g. torch.ones(size) where size is an input proxy
|
|
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
|
else:
|
|
# this is called directly when the inputs do not contain proxy
|
|
# e.g. torch.ones(4) where the input is static
|
|
return target(*args, **kwargs)
|
|
|
|
return wrapper, target
|