ColossalAI/colossalai/_analyzer/fx/tracer/tracer.py

364 lines
15 KiB
Python
Raw Normal View History

import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
from ..node_util import MetaInfo
from .proxy import ColoProxy
Target = Union[Callable[..., Any], str]
def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s)
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
return impl
return wrapper
def register_leaf_module_impl(module: nn.Module):
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
return wrapper
def register_leaf_module(module: nn.Module):
ColoTracer._custom_leaf_module.add(module)
def register_non_leaf_module(module: nn.Module):
ColoTracer._custom_non_leaf_module.add(module)
class ColoTracer(Tracer):
_custom_leaf_module: Set[Type[nn.Module]] = set()
_custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
_custom_non_leaf_module: Set[Type[nn.Module]] = set()
_custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
_bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
_bias_addition_module = [
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
]
def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.disable_module_getattr = False
self.proxy_buffer_attributes = True
# whether the tracer will record the usage of torch.utils.checkpoint
self.trace_act_ckpt = trace_act_ckpt
self.ckpt_regions = []
self.ckpt_idx = 0
self.mod_dir = ''
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
# if bias-addiction split is enabled, and module has bias, then it is not a leaf module
# we will enter the module and split the bias-addition ops
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> Any:
curr_dir = self.mod_dir
self.mod_dir = 'self.' + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
def proxy(self, node: Node) -> 'ColoProxy':
return ColoProxy(node, self)
def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
self.disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
elif kind == 'call_function':
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
self.disable_module_getattr = True
try:
if target == '__call__':
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
finally:
self.disable_module_getattr = False
elif kind == 'call_module':
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
args = tree_map(unwrap_fn, args)
kwargs = tree_map(unwrap_fn, kwargs)
if type(mod) in self._custom_leaf_module:
target = self._custom_leaf_module_impl[type(mod)]
proxy.meta_data = target(mod, *args, **kwargs)
else:
proxy.meta_data = mod.forward(*args, **kwargs)
finally:
self.disable_module_getattr = False
return proxy
def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
for name in names:
if name not in sig_names:
raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward")
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
self.concrete_args = concrete_args
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
self.mod_dir = 'self'
self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = ''
self.graph.lint()
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
# TODO: solves GraphModule creation.
# Without this, return type annotation "Tuple" is causing code execution failure.
if node.op == "output":
node.type = None
return self.graph
@contextmanager
def _tracer_override(self):
# override the tracer to support custom modules and checkpointing
if self.trace_act_ckpt:
orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply
orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant
def checkpoint(run_function, preserve_rng_state=False, *args):
self.ckpt_regions.append(self.ckpt_idx)
out = run_function(*args)
self.ckpt_idx = self.ckpt_regions.pop(-1) + 1
return out
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction.apply = checkpoint
torch.utils.checkpoint._checkpoint_without_reentrant = checkpoint
# override the custom functions
ColoProxy._func_dispatch.update({k: v for k, v in self._custom_impl.items()})
# override the bias addition functions
if self.bias_addition_split:
ColoProxy._func_dispatch.update({k: v for k, v in self._bias_addition_impl.items()})
yield
if self.trace_act_ckpt:
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction.apply = orig_ckpt_func_apply
torch.utils.checkpoint._checkpoint_reentrant = orig_ckpt_func_without_reentrant
ColoProxy._func_dispatch = {}
@contextmanager
def _torch_factory_override(self):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
proxy = self.create_proxy('call_function', target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
else:
return target(*args, **kwargs)
return wrapper, target
overrides = {
target: wrap_factory_method(getattr(torch, target))
for target in _TorchFactoryMethod
if callable(getattr(torch, target))
}
for name, (wrapper, orig) in overrides.items():
setattr(torch, name, wrapper)
yield
# recover the torch factory functions upon exit
for name, (wrapper, orig) in overrides.items():
setattr(torch, name, orig)
def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
if node.op == "output":
node.type = None
self.graph.lint()
def getattr(self, attr, attr_val, parameter_proxy_cache):
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "disable_module_getattr", False):
return attr_val
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
lambda node: ColoProxy(self, node, n, attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
return attr_val