"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""

from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch.fx
from torch.autograd.graph import saved_tensors_hooks
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.fx._compatibility import compatibility

Target = Union[Callable[..., Any], str]


class sim_env(saved_tensors_hooks):
    """
    A simulation of memory allocation and deallocation in the forward pass
    using ``saved_tensors_hooks``.

    Attributes:
        ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of a tensor to the tensor itself. This is used
            to track memory allocation and deallocation.

        param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of each model parameter to the parameter itself.
            This avoids overestimating the memory usage of the intermediate activations.
    """

    def __init__(self, module: Optional[torch.nn.Module] = None):
        super().__init__(self.pack_hook, self.unpack_hook)
        self.ctx = {}
        # Guard against ``module=None``, mirroring the ``buffer_ctx`` handling below.
        self.param_ctx = {param.data_ptr(): param for param in module.parameters()} if module else {}
        self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}

    def pack_hook(self, tensor: torch.Tensor):
        # Only track tensors that are neither parameters nor buffers, i.e. activations.
        if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
            self.ctx[tensor.data_ptr()] = tensor
        return tensor

    def unpack_hook(self, tensor):
        return tensor
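
# Illustrative usage of ``sim_env`` (a minimal sketch, assuming an ordinary
# ``torch.nn.Module`` on real tensors; not executed as part of this module):
#
#   model = torch.nn.Linear(4, 4)
#   hook = sim_env(module=model)
#   with hook:
#       out = model(torch.rand(2, 4)).sum()
#   # ``hook.ctx`` now maps the data pointers of activations saved by autograd
#   # for the backward pass to the tensors themselves; parameters and buffers
#   # are filtered out via ``param_ctx`` and ``buffer_ctx``.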


def _normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def _current_device(module):
    try:
        return next(module.parameters()).device
    except StopIteration:
        # Modules without parameters default to CPU.
        return torch.device("cpu")


@compatibility(is_backward_compatible=False)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record the meta data of the result
    into the corresponding node.

    Usage:
        >>> model = MyModule()
        >>> x = torch.rand(10, 10)
        >>> gm = colossalai.fx.symbolic_trace(model, meta_args={'x': x})
        >>> interp = ShapeProp(gm)
        >>> interp.propagate(x)

    Args:
        module (GraphModule): The module to be executed

    Hints:
        If you want to add a new shape propagation rule, you can do so by
        adding a new method to this class with the ``@register_shape_impl``
        decorator. The method should take ``(*args, **kwargs)`` as its input
        and generate the output.

        For example, if you want to add a shape propagation rule for
        ``torch.nn.functional.linear``, you can do so by adding a new method
        to this class with the ``@register_shape_impl`` decorator (since
        ``MetaTensorMode`` is already compatible with ``torch.nn.functional.linear``,
        in practice you don't have to do so):

        >>> @register_shape_impl(torch.nn.functional.linear)
        >>> def linear_shape_impl(*args, **kwargs):
        >>>     # do something here
        >>>     return torch.empty(output_shape, device=output_device)
    """

    _custom_dispatch_func = {}
    _mode = MetaTensorMode()
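
    # ``_custom_dispatch_func`` is consulted in ``call_function`` below. A registration
    # decorator like ``register_shape_impl`` (referenced in the docstring but defined
    # elsewhere in the codebase) presumably fills it in; a minimal sketch of such a
    # helper, under that assumption, could look like:
    #
    #   def register_shape_impl(func):
    #       def wrapper(impl):
    #           ShapeProp._custom_dispatch_func[func] = impl
    #           return impl
    #       return wrapper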

    def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
        super().__init__(module, garbage_collect_values)
        self.global_hook = sim_env(module=self.module)

    def run_node(self, n: torch.fx.Node) -> Any:
        """
        Run a specific node ``n`` and return the result. Attach
        (``inputs``, ``outputs``, ``parameters``, ``buffers``) to ``n``.

        Args:
            n (Node): The ``Node`` to execute

        Returns:
            Any: The result of executing ``n``
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        with self.global_hook:
            r = getattr(self, n.op)(n.target, args, kwargs)

        def unwrap_fn(elem):
            def _convert_meta(t: torch.Tensor):
                # Move the tensor to the ``meta`` device if it is not there already.
                if t.device.type == "meta":
                    return t
                else:
                    return t.to("meta")

            if isinstance(elem, MetaTensor):
                if getattr(self, "_is_param", False):
                    return torch.nn.Parameter(_convert_meta(elem._tensor))
                return _convert_meta(elem._tensor)

            elif isinstance(elem, torch.Tensor):
                if isinstance(elem, torch.nn.Parameter):
                    return torch.nn.Parameter(_convert_meta(elem))
                return _convert_meta(elem)

            else:
                return elem

        is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
        n_info = MetaInfo(n)
        n_info.outputs = _normalize_tuple(r)

        if n.op == "call_module":
            submod = self.fetch_attr(n.target)
            n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
            n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})

        else:
            n_info.parameters.update(
                {
                    k.name: MetaTensor(v)
                    for k, v in zip(n.args, args)
                    if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
                }
            )
            n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})

        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
            v for v in kwargs.values() if is_pure_tensor(v)
        )

        # align with SPMD
        if isinstance(r, (tuple, list)):
            n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
        else:
            n._meta_data = unwrap_fn(r)

        n_info.global_ctx = self.global_hook.ctx
        n_info.curr_ctx = self.global_hook.ctx.copy()

        # An output counts as an alias if it shares a data pointer with a tensor saved for backward.
        crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False
        n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
        return r
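
    # After ``run_node``, the ``MetaInfo`` attached to the node records:
    #   * ``inputs``      -- MetaTensor activations fed to the node (parameters excluded)
    #   * ``outputs``     -- the normalized outputs of the node
    #   * ``parameters`` / ``buffers`` -- parameters and buffers consumed by the node
    #   * ``global_ctx`` / ``curr_ctx`` -- tensors saved for backward so far, as tracked by ``sim_env``
    #   * ``is_alias``    -- per-output flags marking outputs that alias a saved tensor
    # and ``n._meta_data`` holds the result converted to plain ``meta`` tensors.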

    def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.
        If the target of the ``Node`` is registered with ``@register_shape_impl``,
        the registered function is used to execute the node. This is common
        when customized kernels are inserted into the graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the function invocation
        """
        convert_to_param = False
        if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
            convert_to_param = True
        if target in self._custom_dispatch_func:
            res = self._custom_dispatch_func[target](*args, **kwargs)
        else:
            res = super().call_function(target, args, kwargs)
        if convert_to_param:
            return torch.nn.Parameter(res)
        else:
            return res

    def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        target_method = getattr(self_obj.__class__, target)

        convert_to_parameter = False
        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
            args[0], torch.nn.parameter.Parameter
        ):
            convert_to_parameter = True
        # Execute the method and return the result
        assert isinstance(target, str)
        res = getattr(self_obj, target)(*args_tail, **kwargs)
        if convert_to_parameter:
            return torch.nn.Parameter(res)
        else:
            return res
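
    # Note on ``convert_to_param`` / ``convert_to_parameter`` above: when a graph node
    # merely reshapes, views, or transposes a ``torch.nn.Parameter``, the result is
    # re-wrapped as a ``Parameter``, presumably so that downstream bookkeeping (e.g.
    # the parameter/activation split in ``run_node``) keeps treating it as a parameter
    # rather than as a fresh activation.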

    def propagate(self, *args, device=None):
        """
        Run ``module`` via interpretation, record the shape of each node, and
        return the result.

        Args:
            *args (Tensor): The sample input.
            device (torch.device, optional): The device to attach to the wrapped
                ``MetaTensor`` inputs. Defaults to ``None``.

        Returns:
            Any: The value returned from executing the Module
        """

        # wrap_fn = lambda elem: MetaTensor(elem, device=device)
        def wrap_fn(elem, device=device):
            if isinstance(elem, torch.Tensor):
                return MetaTensor(elem, device=device)
            else:
                return elem

        with self._mode:
            return super().run(*tree_map(wrap_fn, args))


def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
    """
    Run ``module`` via interpretation, record the shape of each ``Node``, and
    return the annotated module.

    Args:
        module (GraphModule): The GraphModule to profile
        *args (Any): The sample input

    Returns:
        GraphModule: The same GraphModule with shape information
    """

    ShapeProp(module).propagate(*args, device=_current_device(module))
    return module
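

# End-to-end usage (a minimal sketch; ``MyModule`` and the input shape are
# placeholders, and tracing assumes colossalai's symbolic tracer as in the
# ``ShapeProp`` docstring above):
#
#   model = MyModule()
#   x = torch.rand(10, 10)
#   gm = colossalai.fx.symbolic_trace(model, meta_args={'x': x})
#   gm = shape_prop_pass(gm, x)
#   for node in gm.graph.nodes:
#       print(node.name, getattr(node, "_meta_data", None))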