2022-09-07 03:21:04 +00:00
|
|
|
from typing import Callable, Any, Dict, Tuple
|
2022-08-24 08:22:44 +00:00
|
|
|
import torch
|
2022-09-14 01:36:43 +00:00
|
|
|
from torch.fx import Graph, Node
|
2022-08-31 08:30:16 +00:00
|
|
|
from torch.fx.node import Argument, Target
|
2022-09-07 03:21:04 +00:00
|
|
|
from torch.utils._pytree import tree_map
|
2022-09-14 06:27:04 +00:00
|
|
|
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
|
2022-09-15 06:46:36 +00:00
|
|
|
from .memory import WEIRD_OPS
|
2022-09-07 03:21:04 +00:00
|
|
|
from .tensor import MetaTensor
|
|
|
|
from .opcount import flop_mapping
|
2022-08-25 15:11:13 +00:00
|
|
|
|
2022-09-14 01:36:43 +00:00
|
|
|
# Names exported by `from <this module> import *`: the three profiler wrappers.
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
2022-09-07 03:21:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def normalize_tuple(x):
    """Return `x` as-is when it is already a tuple, otherwise wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
|
2022-08-25 15:11:13 +00:00
|
|
|
|
2022-08-24 08:22:44 +00:00
|
|
|
|
2022-09-07 03:21:04 +00:00
|
|
|
def is_autogradable(x):
    """Return True only when `x` is a `torch.Tensor` with a floating point dtype."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.is_floating_point()
|
2022-08-24 08:22:44 +00:00
|
|
|
|
2022-09-07 03:21:04 +00:00
|
|
|
|
2022-09-15 06:46:36 +00:00
|
|
|
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
    """
    Profile a Callable function with args and kwargs.

    `target` is executed on `FlopTensor` wrappers so that every aten-level op
    it dispatches is recorded as a node of a private `torch.fx.Graph` and its
    FLOPs are accumulated per phase. If the output is a floating point tensor
    that requires grad, a backward pass is run as well so that backward ops
    are recorded too.

    Args:
        target (Callable): The callable to profile. If it is a `str`, it is
            treated as a method name that is looked up on `args[0]`.
        args (Any): Positional arguments forwarded to `target`.
        kwargs (Any): Keyword arguments forwarded to `target`.

    Returns:
        out (Any): The output of `target`, with every `FlopTensor` replaced by
            its underlying tensor moved to the `meta` device.
        meta_info (GraphInfo): The result of `autograd_graph_analysis` on the
            recorded subgraph, with `fwd_flop` and `bwd_flop` filled in from
            the per-phase FLOP counters.
    """
    # This subgraph traces aten level ops inside one node.
    subgraph = Graph()

    # `flop_count` is shared with `FlopTensor.__torch_dispatch__` through the
    # closure and accumulates the FLOPs of each phase.
    flop_count = {
        Phase.FORWARD: 0,
        Phase.BACKWARD: 0,
    }

    # FlopTensor not only gets the flop statistics of a single node,
    # it also builds a full autograd graph for this node.
    # This makes sure we can analyze the dependencies of memory, and
    # decide which forward intermediate results should be kept until
    # backward is executed.
    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):

        # The subgraph node that produced this tensor.
        _node: Node

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # The signature allows `kwargs=None`; normalize it so that the
            # membership test and the `**kwargs` expansion below cannot fail.
            if kwargs is None:
                kwargs = {}

            def get_node(x):
                return None if not hasattr(x, '_node') else x._node

            # Record the op in the subgraph, wiring it to the nodes that
            # produced its inputs (inputs without a node are recorded as None).
            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = subgraph.create_node('call_function', func, args_node, kwargs_node)

            # do not allocate on `cpu`
            if 'device' in kwargs:
                kwargs['device'] = 'meta'

            def unwrap(x):
                # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                    x = FlopTensor(x.to('meta'))
                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # `phase` is read from the enclosing `_profile` scope at call time,
            # so ops dispatched during backward are attributed to BACKWARD.
            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
            node.meta['out'] = normalize_tuple(out)
            node.meta['phase'] = phase

            def wrap(x):
                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x

            def set_node(x):
                # Only wrapped tensors carry a node; an op output may also
                # contain non-tensor leaves (e.g. `None` or Python scalars)
                # on which attribute assignment would raise.
                if isinstance(x, FlopTensor):
                    x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)
            return out

    # `WEIRD_OPS` are tough to handle because they don't accept autograd
    # on meta tensor, so their inputs are wrapped without requiring grad.
    requires_grad = target not in WEIRD_OPS

    def wrap(x):
        # Leave non floating point values and already wrapped tensors untouched.
        if is_autogradable(x) and not hasattr(x, '_tensor'):
            return FlopTensor(x.detach().requires_grad_(requires_grad))
        return x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    def set_placeholder(x):
        # Every wrapped input becomes a `placeholder` node of the subgraph.
        if isinstance(x, FlopTensor):
            x._node = subgraph.create_node('placeholder',
                                           'placeholder', (subgraph._root,),
                                           name=subgraph._graph_namespace.create_name('input', x._tensor))
            x._node.meta['phase'] = Phase.PLACEHOLDER
            x._node.meta['out'] = (x._tensor,)

    tree_map(set_placeholder, args)
    tree_map(set_placeholder, kwargs)

    def pack(x):
        # Mark the producing node of every tensor autograd saves for backward,
        # except parameters.
        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
            x._node.meta['saved'] = True
        return x

    def unpack(x):
        return x

    # `phase` will mark the phase of autograd from outside scope.
    phase = Phase.FORWARD

    # mark saved tensors with saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        if isinstance(target, str):
            # args[0] is the `self` object for this method call
            self_obj, *args_tail = args
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        else:
            out = target(*args, **kwargs)

        # If the output is not a floating point `torch.Tensor` or it does not
        # requires grad, then we should not run backward for this node.
        if is_autogradable(out) and out.requires_grad:
            phase = Phase.BACKWARD
            if isinstance(out, FlopTensor):
                # NOTE(review): this writes the key 'save' while `pack` writes
                # 'saved'; presumably one of them is a typo - confirm against
                # the consumer of these keys before changing it.
                out._node.meta['save'] = False
            grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
                out, device='meta')
            torch.autograd.backward(out, FlopTensor(grad))

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

    def unwrap(x):
        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

    return tree_map(unwrap, out), graph_info
|
2022-08-24 08:22:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
def profile_function(target: 'Target') -> Callable:
    """
    Build a profiling wrapper around a `call_function` target or a
    `torch.nn.functional` function. The wrapper returns the output of the
    call together with its estimated memory cost and FLOPs.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, meta_info = profile_function(func)(input)
    """
    func = target

    def to_meta(x):
        # Move tensors to the `meta` device, pass everything else through.
        if isinstance(x, torch.Tensor):
            return x.to('meta')
        return x

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # In-place calls skip the autograd profiling: the function is run
        # directly on meta tensors and only the output size is reported.
        if kwargs.get('inplace', False):
            args = tree_map(to_meta, args)
            kwargs = tree_map(to_meta, kwargs)
            out = func(*args, **kwargs)
            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)

        result, info = _profile(func, *args, **kwargs)
        return result, info

    f.__name__ = target.__name__
    return f
|
|
|
|
|
|
|
|
|
|
|
|
def profile_method(target: 'Target') -> Callable:
    """
    Build a profiling wrapper around a `call_method` node. The wrapper
    returns the output of the method call together with its estimated
    memory cost and FLOPs.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # A method target is given by name; the object it is called on is
        # the first positional argument.
        assert isinstance(target, str), f'{target} instance is not str.'
        result, info = _profile(target, *args, **kwargs)
        return result, info

    return f
|
|
|
|
|
|
|
|
|
|
|
|
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Args:
        module (torch.nn.Module): The module whose `forward` is profiled.

    Returns:
        Callable: A function named after the module's class. Calling it with
        the module's inputs returns a pair `(output, meta_info)`.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, meta_info = profile_module(mod)(input)
    """
    func = module.forward

    def to_meta(x):
        # Only tensors have a device; non-tensor arguments (ints, flags,
        # None, ...) must be passed through unchanged.
        return x.to('meta') if isinstance(x, torch.Tensor) else x

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

        # If there is an argument that this `call_module` is inplace, we should
        # skip the autograd profiling.
        if getattr(module, 'inplace', False):
            args = tree_map(to_meta, args)
            kwargs = tree_map(to_meta, kwargs)
            out = func(*args, **kwargs)
            # The first two `GraphInfo` fields get the output element count,
            # the remaining ones are zero.
            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
        out, meta = _profile(func, *args, **kwargs)
        return out, meta

    f.__name__ = module.__class__.__name__
    return f
|