You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/fx/tracer/tracer.py

472 lines
20 KiB

#!/usr/bin/env python
"""
tracer.py:
Implemented a tracer which supports control flow and user-defined meta arguments.
The implementation is partly inspired HuggingFace's fx tracer
"""
import enum
import inspect
import functools
import operator
from contextlib import contextmanager
from colossalai.fx.tracer.meta_patch import meta_patched_module
import torch
import torch.nn as nn
from torch import Tensor
from torch.fx import Tracer, Node
from torch.fx.graph import Graph
from torch.fx.proxy import Proxy, ParameterProxy
from ..proxy import ColoProxy
from typing import Optional, Dict, Any
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
from .meta_patch import meta_patched_function, meta_patched_module
from torch.fx.graph import magic_methods, reflectable_magic_methods
__all__ = ['ColoTracer']
class TracerType(enum.Enum):
DEFAULT = 1
META = 2
class ColoTracer(Tracer):
"""
ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
This tracer is initialized in the same way as the original torch.fx.Tracer.
Usage::
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 10)
self.linear2 = nn.Linear(10, 10)
def forward(self, x, y):
x1 = self.linear1(x)
y1 = self.linear2(y)
if x1.dim() == 2:
return x1 + y1
else:
return x1 - y1
model = Model()
tracer = ColoTracer()
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
"""
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tracer_type = TracerType.META
self.proxy_cls = ColoProxy
# whether the tracer will record the usage of torch.utils.checkpoint
self.trace_act_ckpt = trace_act_ckpt
# whether the current tracing occurs within the activation checkpoint functions
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
"""
Create a proxy for different kinds of operations.
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
if self.tracer_type == TracerType.DEFAULT:
# since meta_args is not given
# we just fall back to the original torch.fx.Tracer
return proxy
proxy: ColoProxy
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
proxy.meta_data = self.meta_args[target]
return proxy
if target in self.orig_torch_tensor_methods:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
if kind == "call_function":
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
# fetch patched method
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
else:
meta_target = method
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if meta_patched_module.has(mod_type):
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
else:
return proxy
if not isinstance(proxy, Proxy):
raise ValueError("Don't support composite output yet")
proxy.meta_data = meta_out
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return proxy
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
lambda node: ParameterProxy(self, node, n, attr_val))
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
module_qualified_name = self.path_of_module(m)
# a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
# which means customized modules are not leaf module by default
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
def proxy(self, node) -> Proxy:
"""
Returns a ColoProxy object.
"""
return self.proxy_cls(node, self)
def _configure_tracer_type(self, tracer_type: TracerType):
if tracer_type == TracerType.DEFAULT:
self.proxy_cls = Proxy
self.tracer_type = TracerType.DEFAULT
elif tracer_type == TracerType.META:
self.proxy_cls = ColoProxy
self.tracer_type = TracerType.META
else:
raise ValueError(f"Unrecognised tracer type {tracer_type}")
def trace(self,
root: nn.Module,
concrete_args: Optional[Dict[str, Tensor]] = None,
meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
"""
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
Args:
root (nn.Module): a `nn.Module` object to trace the computation graph
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
"""
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
if len(meta_args) == 0:
self._configure_tracer_type(TracerType.DEFAULT)
else:
self._configure_tracer_type(TracerType.META)
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
if k in non_meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
# assign as attributed for late reference
def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items():
if not should_be_meta:
assert not torch.is_tensor(v) or not v.is_meta, \
f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
else:
assert v.is_meta == should_be_meta, \
f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
_check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True)
self.concrete_args = concrete_args
self.meta_args = meta_args
self.patched_torch_tensor_methods = {}
if self.tracer_type == TracerType.META:
# wrap the torch tensor constructing methods so that they are captured in the graph
self.patched_torch_tensor_methods = {
target: wrap_tensor_constructor_method(getattr(torch, target))
for target in self._TORCH_METHODS_TO_PATCH
}
# patch these methods to replace their original use
for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
setattr(torch, name, wrapper)
# cache these methods so that we can detect whether a method call
# should be patched during tracing
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
try:
# to track the usage of torch.utils.checkpoint
with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
self.graph = super().trace(root, concrete_args=concrete_args)
finally:
# recover the patched methods
for name, (_, orig) in self.patched_torch_tensor_methods.items():
setattr(torch, name, orig)
if self.tracer_type == TracerType.DEFAULT:
return self.graph
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
# TODO: solves GraphModule creation.
# Without this, return type annotation "Tuple" is causing code execution failure.
if node.op == "output":
node.type = None
return self.graph
@contextmanager
def trace_activation_checkpoint(self, enabled: bool):
if enabled:
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activaton checkpoint part
self.inside_torch_checkpoint_func = True
out = run_function(*args)
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count += 1
return out
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
"We do not implement the backward pass as we only trace the forward pass.")
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
yield
if enabled:
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
return node
def wrap_tensor_constructor_method(target):
def look_for_proxy(*args, **kwargs):
# find in pos vars
for arg in args:
if isinstance(arg, Proxy):
return arg
if isinstance(arg, (tuple, list)):
return look_for_proxy(*arg)
# find in keyword vars
for k, v in kwargs.items():
if isinstance(v, Proxy):
return v
if isinstance(v, (tuple, list)):
return look_for_proxy(*v)
return None
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = look_for_proxy(*args, **kwargs)
if proxy is not None:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
if not isinstance(colo_proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
colo_proxy = ColoProxy(proxy.node)
colo_proxy.meta_data = meta_out
return colo_proxy
else:
# this is called directly when the inputs do not contain proxy
# e.g. torch.ones(4) where the input is static
return target(*args, **kwargs)
return wrapper, target
# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__,
# and add meta_data attribute to the created proxy.
for method in magic_methods:
def _scope(method):
def impl(*args, **kwargs):
tracer = args[0].tracer
target = getattr(operator, method)
proxy = tracer.create_proxy('call_function', target, args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy
impl.__name__ = method
as_magic = f'__{method.strip("_")}__'
setattr(ColoProxy, as_magic, impl)
_scope(method)
def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__'
def impl(self, rhs):
target = getattr(operator, orig_method_name)
proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy
impl.__name__ = method_name
impl.__qualname__ = method_name
setattr(ColoProxy, method_name, impl)
for orig_method_name in reflectable_magic_methods:
_define_reflectable(orig_method_name)