|
|
@ -1,12 +1,10 @@
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
from enum import auto
|
|
|
|
|
|
|
|
from typing import Callable, Any, Dict, Tuple
|
|
|
|
from typing import Callable, Any, Dict, Tuple
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch.fx import Graph, Node
|
|
|
|
from torch.fx import Graph, Node
|
|
|
|
from torch.fx.node import Argument, Target
|
|
|
|
from torch.fx.node import Argument, Target
|
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
|
|
|
|
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
|
|
|
|
from .memory import WEIRD_OPS, activation_size
|
|
|
|
from .memory import WEIRD_OPS
|
|
|
|
from .tensor import MetaTensor
|
|
|
|
from .tensor import MetaTensor
|
|
|
|
from .opcount import flop_mapping
|
|
|
|
from .opcount import flop_mapping
|
|
|
|
|
|
|
|
|
|
|
@ -23,7 +21,7 @@ def is_autogradable(x):
|
|
|
|
return isinstance(x, torch.Tensor) and x.is_floating_point()
|
|
|
|
return isinstance(x, torch.Tensor) and x.is_floating_point()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
|
|
|
|
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Profile a Callable function with args and kwargs.
|
|
|
|
Profile a Callable function with args and kwargs.
|
|
|
|
|
|
|
|
|
|
|
@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
|
|
|
|
# `flop_count`` serves as a global dictionary to store results.
|
|
|
|
# `flop_count`` serves as a global dictionary to store results.
|
|
|
|
flop_count = {
|
|
|
|
flop_count = {
|
|
|
|
Phase.FORWARD: 0,
|
|
|
|
Phase.FORWARD: 0,
|
|
|
|
Phase.LOSS: 0,
|
|
|
|
|
|
|
|
Phase.BACKWARD: 0,
|
|
|
|
Phase.BACKWARD: 0,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
|
|
|
|
kwargs_node = tree_map(get_node, kwargs)
|
|
|
|
kwargs_node = tree_map(get_node, kwargs)
|
|
|
|
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
|
|
|
|
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# do not allocate on `cpu`
|
|
|
|
|
|
|
|
if 'device' in kwargs:
|
|
|
|
|
|
|
|
kwargs['device'] = 'meta'
|
|
|
|
|
|
|
|
|
|
|
|
def unwrap(x):
|
|
|
|
def unwrap(x):
|
|
|
|
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
|
|
|
|
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
|
|
|
|
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
|
|
|
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
|
|
@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
|
|
|
|
if target not in WEIRD_OPS:
|
|
|
|
if target not in WEIRD_OPS:
|
|
|
|
|
|
|
|
|
|
|
|
def wrap(x):
|
|
|
|
def wrap(x):
|
|
|
|
return FlopTensor(x.detach().requires_grad_(
|
|
|
|
return FlopTensor(
|
|
|
|
True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
|
|
|
|
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
def wrap(x):
|
|
|
|
def wrap(x):
|
|
|
|
return FlopTensor(x.detach().requires_grad_(
|
|
|
|
return FlopTensor(
|
|
|
|
False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
|
|
|
|
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
|
|
|
|
|
|
|
|
|
|
|
|
# Basically, we need to detach the args and kwargs from the outer graph.
|
|
|
|
# Basically, we need to detach the args and kwargs from the outer graph.
|
|
|
|
args = tree_map(wrap, args)
|
|
|
|
args = tree_map(wrap, args)
|
|
|
@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
|
|
|
|
tree_map(set_placeholder, kwargs)
|
|
|
|
tree_map(set_placeholder, kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def pack(x):
|
|
|
|
def pack(x):
|
|
|
|
if isinstance(x, FlopTensor):
|
|
|
|
if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
|
|
|
|
x._node.meta['saved'] = True
|
|
|
|
x._node.meta['saved'] = True
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
out = target(*args, **kwargs)
|
|
|
|
out = target(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
# If the output is not a floating point `torch.Tensor` or it does not
|
|
|
|
# If the output is not a floating point `torch.Tensor` or it does not
|
|
|
|
# requires grad, then we should not run backward for this node.
|
|
|
|
# requires grad, then we should not run backward for this node.
|
|
|
|
if is_autogradable(out) and out.requires_grad:
|
|
|
|
if is_autogradable(out) and out.requires_grad:
|
|
|
|
phase = Phase.LOSS
|
|
|
|
phase = Phase.BACKWARD
|
|
|
|
loss = out.sum()
|
|
|
|
if isinstance(out, FlopTensor):
|
|
|
|
phase = Phase.BACKWARD
|
|
|
|
out._node.meta['save'] = False
|
|
|
|
loss.backward()
|
|
|
|
grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
|
|
|
|
|
|
|
|
out, device='meta')
|
|
|
|
|
|
|
|
torch.autograd.backward(out, FlopTensor(grad))
|
|
|
|
|
|
|
|
|
|
|
|
graph_info = autograd_graph_analysis(subgraph)
|
|
|
|
graph_info = autograd_graph_analysis(subgraph)
|
|
|
|
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
|
|
|
|
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
|
|
|
@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
|
>>> input = torch.rand(100, 100, 100, 100, device='meta')
|
|
|
|
>>> input = torch.rand(100, 100, 100, 100, device='meta')
|
|
|
|
>>> func = torch.nn.functional.relu
|
|
|
|
>>> func = torch.nn.functional.relu
|
|
|
|
>>> output, meta_info = profile_function(func)(input, inplace=False)
|
|
|
|
>>> output, meta_info = profile_function(func)(input)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
|
|
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
|
|
@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
|
|
|
|
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
|
|
|
|
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
|
|
|
|
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
|
|
|
|
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
|
|
|
|
out = func(*args, **kwargs)
|
|
|
|
out = func(*args, **kwargs)
|
|
|
|
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
|
|
|
|
return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
|
|
|
|
out, meta = _profile(func, *args, **kwargs)
|
|
|
|
out, meta = _profile(func, *args, **kwargs)
|
|
|
|
return out, meta
|
|
|
|
return out, meta
|
|
|
|
|
|
|
|
|
|
|
@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
|
|
|
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
|
|
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
|
|
|
# execute the method and return the result
|
|
|
|
# execute the method and return the result
|
|
|
|
assert isinstance(target, str), f'{target} instance is not str.'
|
|
|
|
assert isinstance(target, str), f'{target} instance is not str.'
|
|
|
|
out, meta = _profile(target, *args, inplace=False, **kwargs)
|
|
|
|
out, meta = _profile(target, *args, **kwargs)
|
|
|
|
return out, meta
|
|
|
|
return out, meta
|
|
|
|
|
|
|
|
|
|
|
|
return f
|
|
|
|
return f
|
|
|
@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
|
|
|
args = tree_map(lambda x: x.to('meta'), args)
|
|
|
|
args = tree_map(lambda x: x.to('meta'), args)
|
|
|
|
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
|
|
|
|
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
|
|
|
|
out = func(*args, **kwargs)
|
|
|
|
out = func(*args, **kwargs)
|
|
|
|
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
|
|
|
|
return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
|
|
|
|
out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
|
|
|
|
out, meta = _profile(func, *args, **kwargs)
|
|
|
|
return out, meta
|
|
|
|
return out, meta
|
|
|
|
|
|
|
|
|
|
|
|
f.__name__ = module.__class__.__name__
|
|
|
|
f.__name__ = module.__class__.__name__
|
|
|
|