mirror of https://github.com/hpcaitech/ColossalAI
195 lines
7.2 KiB
Python
195 lines
7.2 KiB
Python
"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""
|
|
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.fx
|
|
from torch.autograd.graph import saved_tensors_hooks
|
|
from torch.utils._pytree import tree_map
|
|
|
|
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
|
from colossalai._analyzer.fx.node_util import MetaInfo
|
|
from colossalai.fx._compatibility import compatibility
|
|
|
|
# Type of an FX node's ``target`` as passed to ``ShapeProp.call_function``:
# either a callable or a string name.
Target = Union[Callable[..., Any], str]
|
|
|
|
|
|
class sim_env(saved_tensors_hooks):
    """
    A simulation of memory allocation and deallocation in the forward pass
    using ``saved_tensors_hooks``.

    While this context manager is active, tensors saved by autograd are
    passed to :meth:`pack_hook`, which records them in ``ctx`` unless they
    are one of ``module``'s parameters or buffers.

    Args:
        module (Optional[torch.nn.Module]): Module whose parameters and
            buffers are excluded from ``ctx``. When ``None``, nothing is
            excluded.

    Attributes:
        ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of a tensor to the tensor itself. This is used
            to track the memory allocation and deallocation.

        param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of all model parameters to the parameter itself.
            This avoids overestimating the memory usage of the intermediate activations.

        buffer_ctx (Dict[int, torch.Tensor]): Same as ``param_ctx``, but for
            the module's buffers.
    """

    def __init__(self, module: Optional[torch.nn.Module] = None):
        super().__init__(self.pack_hook, self.unpack_hook)
        self.ctx: Dict[int, torch.Tensor] = {}
        # ``module`` is optional, so both lookups must be guarded. Compare
        # with ``None`` explicitly: containers such as ``nn.Sequential``
        # define ``__len__`` and an empty one would be falsy.
        if module is None:
            self.param_ctx: Dict[int, torch.Tensor] = {}
            self.buffer_ctx: Dict[int, torch.Tensor] = {}
        else:
            self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
            self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()}

    def pack_hook(self, tensor: torch.Tensor) -> torch.Tensor:
        """Record ``tensor`` in ``ctx`` unless it is a known parameter or buffer; return it unchanged."""
        ptr = tensor.data_ptr()
        if ptr not in self.param_ctx and ptr not in self.buffer_ctx:
            self.ctx[ptr] = tensor
        return tensor

    def unpack_hook(self, tensor):
        """Return the saved ``tensor`` unchanged."""
        return tensor
|
|
|
|
|
|
def _normalize_tuple(x):
|
|
if not isinstance(x, tuple):
|
|
return (x,)
|
|
return x
|
|
|
|
|
|
def _current_device(module):
|
|
return next(module.parameters()).device
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record the meta data of the result
    into the corresponding node.

    Usage:
        >>> model = MyModule()
        >>> x = torch.rand(10, 10)
        >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
        >>> interp = ShapeProp(gm)
        >>> interp.propagate(x)

    Args:
        module (GraphModule): The module to be executed

    Hints:
        If you want to add a new shape propagation rule, you can do so by
        registering a new function with the ``@register_shape_impl``
        decorator (defined elsewhere; presumably it fills
        ``_custom_dispatch_func``, which ``call_function`` consults). The
        function should accept ``(*args, **kwargs)`` as its input and
        return the output.

        For example, if you want to add a shape propagation rule for
        ``torch.nn.functional.linear``, you can do so by registering it with
        the ``@register_shape_impl`` decorator (since ``MetaTensorMode`` is
        already compatible with ``torch.nn.functional.linear``, in practice
        you don't have to do the following):

        >>> @register_shape_impl(torch.nn.functional.linear)
        >>> def linear_shape_impl(*args, **kwargs):
        >>>     # do something here
        >>>     return torch.empty(output_shape, device=output_device)
    """
    # Class-level: shared by every ShapeProp instance.
    _custom_dispatch_func = {}
    _mode = MetaTensorMode()

    def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
        super().__init__(module, garbage_collect_values)
        self.global_hook = sim_env(module=self.module)

    def run_node(self, n: torch.fx.Node) -> Any:
        """
        Run a specific node ``n`` and return the result. Attach
        (``inputs``, ``outputs``, ``parameters``, ``buffers``, ``global_ctx``,
        ``curr_ctx``, ``is_alias``) to the node's ``MetaInfo`` and store the
        unwrapped outputs in ``n._meta_data``.

        Args:
            n (Node): The ``Node`` to execute

        Returns:
            Any: The result of executing ``n``
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        # Run under the saved-tensor hooks so that tensors saved for backward
        # (other than parameters/buffers) are recorded in ``global_hook.ctx``.
        with self.global_hook:
            r = getattr(self, n.op)(n.target, args, kwargs)

        def unwrap_fn(elem):
            return elem._tensor if isinstance(elem, MetaTensor) else elem

        def is_pure_tensor(elem):
            # Parameters are tracked separately in ``n_info.parameters``.
            return isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)

        n_info = MetaInfo(n)
        n_info.outputs = _normalize_tuple(r)

        if n.op == 'call_module':
            submod = self.fetch_attr(n.target)
            n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
            n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
        else:
            # Positional parameter arguments are keyed by the name of the node that produced them.
            n_info.parameters.update({
                k.name: MetaTensor(v)
                for k, v in zip(n.args, args)
                if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
            })
            n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})

        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
            tuple(v for v in kwargs.values() if is_pure_tensor(v))

        n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))    # align with SPMD

        # ``global_ctx`` is the live dict (it keeps growing as later nodes run);
        # ``curr_ctx`` is a snapshot of what had been recorded after this node.
        n_info.global_ctx = self.global_hook.ctx
        n_info.curr_ctx = self.global_hook.ctx.copy()

        def crit(x):
            # Flag tensor outputs whose data pointer has been recorded in ``ctx``.
            return x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False

        n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
        return r

    def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.
        If the target of ``Node`` is registered in ``_custom_dispatch_func``
        (e.g. via ``@register_shape_impl``), the registered function will be
        used to execute the node. This is common if we insert some customized
        kernels.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the function invocation
        """
        if target in self._custom_dispatch_func:
            return self._custom_dispatch_func[target](*args, **kwargs)
        else:
            return super().call_function(target, args, kwargs)

    def propagate(self, *args, device=None):
        """
        Run ``module`` via interpretation, return the result, and record the
        shape of each node.

        Args:
            *args (Any): The sample inputs. Tensor leaves are wrapped into
                ``MetaTensor``; non-tensor leaves (e.g. ``int``, ``None``)
                are forwarded unchanged.
            device (Optional[torch.device]): Forwarded to ``MetaTensor`` for
                every wrapped tensor input. Defaults to ``None``.

        Returns:
            Any: The value returned from executing the Module
        """

        def wrap_fn(elem):
            # Only tensors are handed to the ``MetaTensor`` constructor
            # (which presumably expects a tensor); other sample inputs pass through.
            return MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem

        with self._mode:
            return super().run(*tree_map(wrap_fn, args))
|
|
|
|
|
|
def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
    """
    Interpret ``module`` on the given sample inputs with ``ShapeProp`` so
    that every ``Node`` carries shape/meta information afterwards.

    The sample inputs are placed on the device reported by
    ``_current_device(module)``.

    Args:
        module (GraphModule): The GraphModule to profile
        *args (Any): The sample input

    Returns:
        GraphModule: ``module`` itself, whose nodes were annotated during propagation
    """
    # Build the interpreter before querying the device, matching the
    # original left-to-right evaluation order.
    interpreter = ShapeProp(module)
    interpreter.propagate(*args, device=_current_device(module))
    return module
|