import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod

from ..node_util import MetaInfo
from .proxy import ColoProxy

Target = Union[Callable[..., Any], str]


def _truncate_suffix(s: str):
    import re

    # FIXME: torch.fx appends a numeric suffix such as '_1' to de-duplicate node
    # names; strip it so the name can be matched against the original argument.
    return re.sub(r'_\d+$', '', s)


def register_tracer_impl(func: Callable[..., Any], name: str = '_custom_impl'):

    def wrapper(impl):
        assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
        getattr(ColoTracer, name)[func] = impl
        return impl

    return wrapper
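

# Usage sketch (illustrative only: the target function and the body below are
# hypothetical; a registered impl is dispatched through ColoProxy._func_dispatch
# in place of the original call during tracing):
#
#   @register_tracer_impl(torch.nn.functional.linear)
#   def linear_impl(input, weight, bias=None):
#       out = input @ weight.transpose(0, 1)
#       return out if bias is None else out + bias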


def register_leaf_module_impl(module: Type[nn.Module]):

    def wrapper(impl):
        ColoTracer._custom_leaf_module_impl[module] = impl
        return impl

    return wrapper


def register_leaf_module(module: Type[nn.Module]):
    ColoTracer._custom_leaf_module.add(module)


def register_non_leaf_module(module: Type[nn.Module]):
    ColoTracer._custom_non_leaf_module.add(module)
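

# Usage sketch (illustrative; ``FusedRMSNorm`` is a hypothetical module class):
#
#   register_leaf_module(FusedRMSNorm)
#
#   @register_leaf_module_impl(FusedRMSNorm)
#   def rmsnorm_meta(module, x):
#       # leaf modules are not traced into, so provide the meta-data their
#       # forward would have produced (here: same shape/dtype as the input)
#       return torch.empty_like(x)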


class ColoTracer(Tracer):
    _custom_leaf_module: Set[Type[nn.Module]] = set()
    _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
    _custom_non_leaf_module: Set[Type[nn.Module]] = set()
    _custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
    _bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
    _bias_addition_module = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d,
    ]

    def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_module_getattr = False
        self.proxy_buffer_attributes = True

        # whether the tracer will record the usage of torch.utils.checkpoint
        self.trace_act_ckpt = trace_act_ckpt
        self.ckpt_regions = []
        self.ckpt_idx = 0

        self.mod_dir = ''

        # whether the tracer should split the bias-addition ops into two ops
        self.bias_addition_split = bias_addition_split

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # if bias-addition split is enabled and the module has a bias, it is not
        # a leaf module: we enter the module and split its bias-addition ops
        if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
            return False
        # users can specify which modules are leaf modules and which are not
        return (type(m) not in self._custom_non_leaf_module
                and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
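
    # Example (illustrative): with ``bias_addition_split=True``, a
    # ``torch.nn.Linear(8, 4)`` with a bias is *not* a leaf, so tracing descends
    # into its forward, where a registered bias-addition impl can record the
    # matmul and the bias add as separate nodes; with ``bias=False`` the module
    # remains a leaf.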

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
                    kwargs: Dict[str, Any]) -> Any:
        curr_dir = self.mod_dir
        self.mod_dir = 'self.' + self.path_of_module(m)
        rst = super().call_module(m, forward, args, kwargs)
        self.mod_dir = curr_dir
        return rst

    def proxy(self, node: Node) -> 'ColoProxy':
        return ColoProxy(node, self)
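
    # Example (illustrative): while tracing inside ``root.encoder.layers[0]``,
    # ``mod_dir`` is 'self.encoder.layers.0', so every node created there is
    # tagged with the module path it originated from (see ``create_node``).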

    def create_proxy(self,
                     kind: str,
                     target: Target,
                     args: Tuple[Any, ...],
                     kwargs: Dict[str, Any],
                     name: Optional[str] = None,
                     type_expr: Optional[Any] = None,
                     proxy_factory_fn: Optional[Callable[[Node], 'Proxy']] = None):

        proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
        if kind == 'placeholder':
            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
                _truncate_suffix(target), None)
        elif kind == 'get_attr':
            self.disable_module_getattr = True
            try:
                attr_itr = self.root
                atoms = target.split(".")
                for atom in atoms:
                    attr_itr = getattr(attr_itr, atom)
                proxy.meta_data = attr_itr
            finally:
                self.disable_module_getattr = False
        elif kind == 'call_function':
            proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        elif kind == 'call_method':
            self.disable_module_getattr = True
            try:
                if target == '__call__':
                    proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                else:
                    if target not in _TensorPropertyMethod:
                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
                                                                               **tree_map(unwrap_fn, kwargs))
            finally:
                self.disable_module_getattr = False
        elif kind == 'call_module':
            mod = self.root.get_submodule(target)
            self.disable_module_getattr = True
            try:
                args = tree_map(unwrap_fn, args)
                kwargs = tree_map(unwrap_fn, kwargs)
                if type(mod) in self._custom_leaf_module:
                    target = self._custom_leaf_module_impl[type(mod)]
                    proxy.meta_data = target(mod, *args, **kwargs)
                else:
                    proxy.meta_data = mod.forward(*args, **kwargs)
            finally:
                self.disable_module_getattr = False
        return proxy
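
    # How meta-data flows (illustrative): if a placeholder ``x`` is traced with
    # meta_args={'x': torch.empty(4, 8, device='meta')}, then ``x.view(2, 16)``
    # is recorded as a 'call_method' node whose meta-data is the result of
    # calling ``view`` on the meta tensor, i.e. a meta tensor of shape (2, 16).
    # Shape and dtype propagation thus happen during tracing itself.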

    def create_node(self, *args, **kwargs) -> Node:
        node = super().create_node(*args, **kwargs)
        # MetaInfo registers itself on the node as a side effect of
        # construction; the binding is kept only for that side effect
        n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
        return node

    def trace(self,
              root: torch.nn.Module,
              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:

        if meta_args is None:
            meta_args = {}

        if concrete_args is None:
            concrete_args = {}

        # check that concrete and meta args have valid names
        sig = inspect.signature(root.forward)
        sig_names = set(sig.parameters.keys())
        meta_arg_names = set(meta_args.keys())
        concrete_arg_names = set(concrete_args.keys())
        non_concrete_arg_names = sig_names - concrete_arg_names

        # update concrete args with default values
        for k, v in sig.parameters.items():
            if k in sig_names - meta_arg_names and \
                    k not in concrete_args and \
                    v.default is not inspect.Parameter.empty:
                concrete_args[k] = v.default

        def _check_arg_name_valid(names: Iterable[str]):
            for name in names:
                if name not in sig_names:
                    raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward")

        _check_arg_name_valid(meta_arg_names)
        _check_arg_name_valid(concrete_arg_names)

        self.concrete_args = concrete_args
        self.meta_args = meta_args

        with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
            self.mod_dir = 'self'
            self.graph = super().trace(root, concrete_args=concrete_args)
            self.mod_dir = ''
        self.graph.lint()

        # strip concrete-arg placeholders and fix up node types before returning
        self._post_check(non_concrete_arg_names)
        return self.graph
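
    # Usage sketch (illustrative; any nn.Module whose forward takes ``input``):
    #
    #   model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
    #   tracer = ColoTracer(bias_addition_split=True)
    #   graph = tracer.trace(model, meta_args={'input': torch.empty(2, 8, device='meta')})
    #   gm = torch.fx.GraphModule(model, graph)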

    @contextmanager
    def _tracer_override(self):
        # override the tracer to support custom modules and checkpointing
        if self.trace_act_ckpt:
            orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply
            orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant

            def checkpoint(run_function, preserve_rng_state=False, *args):
                self.ckpt_regions.append(self.ckpt_idx)
                out = run_function(*args)
                self.ckpt_idx = self.ckpt_regions.pop(-1) + 1
                return out

            # override the checkpoint functions
            torch.utils.checkpoint.CheckpointFunction.apply = checkpoint
            torch.utils.checkpoint._checkpoint_without_reentrant = checkpoint

        # override the custom functions
        ColoProxy._func_dispatch.update({k: v for k, v in self._custom_impl.items()})

        # override the bias addition functions
        if self.bias_addition_split:
            ColoProxy._func_dispatch.update({k: v for k, v in self._bias_addition_impl.items()})

        yield

        if self.trace_act_ckpt:
            # restore the original checkpoint functions upon exit
            torch.utils.checkpoint.CheckpointFunction.apply = orig_ckpt_func_apply
            torch.utils.checkpoint._checkpoint_without_reentrant = orig_ckpt_func_without_reentrant

        ColoProxy._func_dispatch = {}
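
    # Effect sketch (illustrative): for a forward such as
    #
    #   def forward(self, x):
    #       return torch.utils.checkpoint.checkpoint(self.block, x)
    #
    # the override above runs ``self.block`` inline while pushing a region index
    # onto ``ckpt_regions``, so every node created inside the block is tagged
    # with its checkpoint region via ``create_node``.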

    @contextmanager
    def _torch_factory_override(self):
        # override the torch factory functions to create a proxy when such a
        # method is called on a proxy during ``symbolic_trace()``
        def wrap_factory_method(target):

            @functools.wraps(target)
            def wrapper(*args, **kwargs):
                is_proxy = any(isinstance(p, ColoProxy) for p in args) or any(
                    isinstance(p, ColoProxy) for p in kwargs.values())
                if is_proxy:
                    # if any arg is a proxy, record this factory call on it,
                    # e.g. torch.ones(size) where size is an input proxy
                    self.disable_module_getattr = True
                    try:
                        proxy = self.create_proxy('call_function', target, args, kwargs)
                    finally:
                        self.disable_module_getattr = False
                    return proxy
                else:
                    return target(*args, **kwargs)

            return wrapper, target

        overrides = {
            target: wrap_factory_method(getattr(torch, target))
            for target in _TorchFactoryMethod
            if callable(getattr(torch, target))
        }
        for name, (wrapper, orig) in overrides.items():
            setattr(torch, name, wrapper)

        yield

        # restore the original torch factory functions upon exit
        for name, (wrapper, orig) in overrides.items():
            setattr(torch, name, orig)
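
    # Why this is needed (illustrative): in a forward like
    #
    #   def forward(self, x):
    #       return x * torch.ones(x.shape[0], x.shape[1])
    #
    # ``x.shape[0]`` is a proxy during tracing, so the unwrapped factory call
    # would fail; the wrapper records it as a 'call_function' node instead.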

    def _post_check(self, non_concrete_arg_names: Set[str]):
        # This is necessary because concrete args are added as input to the
        # traced module since https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in non_concrete_arg_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg, so it is not used and should be removed.
                else:
                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                        # Newer versions of torch.fx emit an assert statement
                        # for concrete arguments; delete those before we delete
                        # the concrete arg.
                        to_delete = []
                        for user in node.users:
                            if user.target == torch.fx._symbolic_trace._assert_is_none:
                                to_delete.append(user)
                        for user in to_delete:
                            self.graph.erase_node(user)

                    self.graph.erase_node(node)

            # TODO: solve GraphModule creation.
            # Without this, a "Tuple" return type annotation causes code execution to fail.
            if node.op == "output":
                node.type = None
        self.graph.lint()

    def getattr(self, attr, attr_val, parameter_proxy_cache):
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "disable_module_getattr", False):
            return attr_val

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
                            # ColoProxy is constructed as (node, tracer); its
                            # meta-data is filled in by create_proxy afterwards
                            kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
                                                          lambda node: ColoProxy(node, self))
                        val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs)    # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
                                                             parameter_proxy_cache)
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        return attr_val
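

# End-to-end sketch (illustrative; the project's symbolic-trace helper
# typically wraps these steps):
#
#   class MLP(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fc = torch.nn.Linear(8, 4)
#
#       def forward(self, x):
#           return self.fc(x).relu()
#
#   tracer = ColoTracer(bias_addition_split=True)
#   graph = tracer.trace(MLP(), meta_args={'x': torch.empty(2, 8, device='meta')})
#   # with bias_addition_split, the linear can appear as a matmul followed by a
#   # separate add, which downstream passes may schedule independently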