[fx] provide an accurate estimation of memory. (#1587)

* [fx] add some comment and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add interpretation for graph analysis.

* [fx] before doing save_tensor_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.
pull/1604/head
Super Daniel 2022-09-14 09:36:43 +08:00 committed by GitHub
parent 27fe8af60c
commit 5c494d4540
6 changed files with 301 additions and 95 deletions

View File

@ -1,10 +1,12 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
@compatibility(is_backward_compatible=True)
@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
class MetaInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node with meta tensor and
record the shape, FLOPs, MACs and type of the result
record the memory usage, FLOPs, and type of the result
into the corresponding node.
Usage:
@ -82,7 +84,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
Any: The result of executing ``n``
"""
result, flop_count, mem_stat = super().run_node(n)
result, meta_info = super().run_node(n)
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
@ -90,21 +92,20 @@ class MetaInfoProp(torch.fx.Interpreter):
else:
return TensorMetadata(None, None, False, None, 0, False)
meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = meta
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
n.meta['type'] = type(result)
# retain the autograd graph
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
@ -125,12 +126,9 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, (0, 0), (0, activation_size(result), 0, 0)
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@ -147,10 +145,9 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@ -166,8 +163,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@ -186,8 +182,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return profile_method(target)(*args, **kwargs)
@ -205,8 +200,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
@ -229,10 +223,9 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return args[0], (0, 0), (0, 0, 0, 0)
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
def propagate(self, *args):
"""

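For reference, a minimal usage sketch of the updated `MetaInfoProp` pass. This is illustrative only: the module path is assumed from the repository layout, and `torch.fx.symbolic_trace` is used here for brevity although Colossal-AI's own tracer may be preferred in practice.

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp  # path assumed

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
gm = symbolic_trace(model)

# Propagate meta information with an input on the `meta` device.
MetaInfoProp(gm).propagate(torch.rand(8, 64, device='meta'))

for n in gm.graph.nodes:
    # Per-node keys extended into `n.meta` from `GraphInfo` by this commit.
    print(n.name,
          n.meta.get('fwd_flop', 0), n.meta.get('bwd_flop', 0),
          n.meta.get('fwd_mem_in', 0), n.meta.get('fwd_mem_tmp', 0),
          n.meta.get('bwd_mem_tmp', 0), n.meta.get('bwd_mem_out', 0))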
View File

@ -2,8 +2,9 @@ from ... import META_COMPATIBILITY
if META_COMPATIBILITY:
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module, _profile
from .profiler import profile_function, profile_method, profile_module
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
from .dataflow import GraphInfo
from .memory import parameter_size, activation_size

View File

@ -0,0 +1,136 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size
class Stage(Enum):
FORWARD = 0
LOSS = 1
BACKWARD = 2
PLACEHOLDER = 3
@dataclass
class GraphInfo:
"""
GraphInfo is a dataclass for MetaInfo, which measures
the execution memory cost and FLOPs with `MetaTensor`.
The dataflow analysis is conducted on a single node of the FX graph.
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node.
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
fwd_mem_tmp: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0
def is_forward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.FORWARD
def is_loss(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.LOSS
def is_placeholder(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.PLACEHOLDER
def is_backward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.BACKWARD
def is_saved(n: Node):
return n.meta.get('saved', False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically, every node of the input graph should be marked with Stage.FORWARD, Stage.LOSS, Stage.BACKWARD, or Stage.PLACEHOLDER under the `stage` key of `node.meta`.
Each node should also carry its output under the `out` key of `node.meta`.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
|\________ |
|
f --------> b
|\ \_____
| \ /
f f ----> b <---- Not every forward result needs to be saved for backward
| \____
|
f ----> b <---- Backward results can be freed as soon as they are no longer required.
l
=============================================================================
Args:
graph (Graph): The autograd graph whose nodes are marked with Stage values under the `stage` key of `node.meta`.
Returns:
graph_info (GraphInfo): Meta information for the dataflow.
"""
def _peak_memory(deps: Dict[Node, int]):
bwd_tmp = 0
for k, v in deps.items():
if v > 0:
bwd_tmp += activation_size(k.meta['out'])
return bwd_tmp
# deps is used to track all the memory dependencies of the graph.
deps = {}
graph_info = GraphInfo()
for n in graph.nodes:
n: Node
if is_saved(n) and not any(map(is_loss, n.users)):
# A forward tensor that is marked `saved` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_in`.
# Any `fwd_in` should be kept in memory even if this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
# the node, `fwd_tmp` can be freed.
if is_placeholder(n):
graph_info.fwd_mem_in += activation_size(n.meta['out'])
if is_forward(n):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
elif is_backward(n):
if len(n.users):
# liveness analysis is only used in backward
deps[n] = len(n.users)
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
else:
# basically, a backward node with no users is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['out'])
return graph_info

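To illustrate how `autograd_graph_analysis` consumes the per-node `stage`, `out`, and `saved` markers, here is a hand-built toy graph. This is a sketch only; the module path `colossalai.fx.profiler.dataflow` is assumed from the imports shown in this diff.

import torch
from torch.fx import Graph
from colossalai.fx.profiler.dataflow import Stage, autograd_graph_analysis  # path assumed

g = Graph()
# A saved placeholder counts toward `fwd_mem_in`.
x = g.placeholder('x')
x.meta = {'stage': Stage.PLACEHOLDER, 'out': (torch.empty(4, 4, device='meta'),), 'saved': True}
# A saved forward activation (not feeding the loss) counts toward `fwd_mem_tmp`.
y = g.call_function(torch.relu, (x,))
y.meta = {'stage': Stage.FORWARD, 'out': (torch.empty(4, 4, device='meta'),), 'saved': True}
# A backward node with no users is treated as `grad_out` and counts toward `bwd_mem_out`.
gy = g.call_function(torch.ones_like, (y,))
gy.meta = {'stage': Stage.BACKWARD, 'out': (torch.empty(4, 4, device='meta'),)}

info = autograd_graph_analysis(g)
print(info.fwd_mem_in, info.fwd_mem_tmp, info.bwd_mem_out)  # 64 64 64 (bytes, fp32 4x4)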
View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
@ -6,6 +7,44 @@ from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLAC
__all__ = ['profile_function', 'profile_module', 'profile_method']
# this is for compatibility use
@dataclass
class GraphInfo:
"""
GraphInfo is a dataclass for MetaInfo, which measures
the execution memory cost and FLOPs with `MetaTensor`.
The dataflow analysis is conducted on a single node of the FX graph.
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node.
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
fwd_mem_tmp: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0
CALL_FUNCTION_MSG = \
"""
Colossal-AI does not yet support profiling for {}; you might manually patch it with the following code.\n
@ -59,7 +98,7 @@ def profile_function(target: 'Target') -> Callable:
else:
profiler = meta_profiler_function.get(target.__name__)
fwd_flop, _ = profiler(*args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = target.__name__
func = target
@ -88,7 +127,7 @@ def profile_method(target: 'Target') -> Callable:
# call_method has no parameters, is MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return f
@ -118,7 +157,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = module.__class__.__name__
func = module.forward

View File

@ -14,12 +14,10 @@ if META_COMPATIBILITY:
INPLACE_ATEN = [
aten.add_.Tensor,
aten.add.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.mul.Tensor,
aten.bernoulli_.float,
# inplace reshaping

View File

@ -1,13 +1,16 @@
from dataclasses import dataclass
from enum import auto
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
from .dataflow import autograd_graph_analysis, Stage
from .memory import WEIRD_OPS
from .tensor import MetaTensor
from .opcount import flop_mapping
__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
__all__ = ['profile_function', 'profile_module', 'profile_method']
def normalize_tuple(x):
@ -20,8 +23,9 @@ def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
"""Profile a Callable function with args and kwargs.
def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
"""
Profile a Callable function with args and kwargs.
Args:
target (Callable): A Callable function
@ -29,25 +33,32 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
kwargs (Any): Argument
Returns:
out (Tuple[Any, ...]): The argument value that was retrieved
flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
out (Tuple[Any, ...]): The argument value that was retrieved.
meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
# This subgraph traces aten level ops inside one node.
subgraph = Graph()
# `flop_count` serves as a global dictionary to store results.
flop_count = {
'f': 0,
'l': 0,
'b': 0,
Stage.FORWARD: 0,
Stage.LOSS: 0,
Stage.BACKWARD: 0,
}
temp = {
'f': [],
'l': [],
'b': [],
}
stage = 'f'
# `stage` will mark the stage of autograd from outside scope.
stage = Stage.FORWARD
# FlopTensor not only gets the flop statistics of a single node,
# it also builds a full autograd graph for this node.
# This makes sure we can analyze the memory dependencies and
# decide which forward intermediate results should be kept until
# backward is executed.
# Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor):
_node: Node
def __repr__(self):
if self.grad_fn:
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
@ -56,66 +67,98 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def get_node(x):
return None if not hasattr(x, '_node') else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
def unwrap(x):
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = FlopTensor(x.to('meta'))
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
def to_meta(x):
return x.to('meta') if isinstance(x, torch.Tensor) else x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
if func not in INPLACE_ATEN:
temp[stage].append(tree_map(to_meta, normalize_tuple(out)))
node.meta['out'] = normalize_tuple(out)
node.meta['stage'] = stage
def wrap(x):
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
def set_node(x):
x._node = node
out = tree_map(wrap, out)
tree_map(set_node, out)
return out
# `WEIRD_OPS` are tough to handle because they don't accept autograd
# on meta tensors.
if target not in WEIRD_OPS:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
return FlopTensor(x.detach().requires_grad_(
True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
else:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
return FlopTensor(x.detach().requires_grad_(
False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
# Basically, we need to detach the args and kwargs from the outer graph.
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
if isinstance(target, str):
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
out = getattr(self_obj, target)(*args_tail, **kwargs)
else:
out = target(*args, **kwargs)
def set_placeholder(x):
if isinstance(x, FlopTensor):
x._node = subgraph.create_node('placeholder',
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['stage'] = Stage.PLACEHOLDER
x._node.meta['out'] = (x._tensor,)
tree_map(set_placeholder, args)
tree_map(set_placeholder, kwargs)
def pack(x):
if isinstance(x, FlopTensor):
x._node.meta['saved'] = True
return x
def unpack(x):
return x
# mark saved tensors with saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
if isinstance(target, str):
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
out = getattr(self_obj, target)(*args_tail, **kwargs)
else:
out = target(*args, **kwargs)
# If the output is not a floating point `torch.Tensor` or it does not
# require grad, then we should not run backward for this node.
if is_autogradable(out) and out.requires_grad:
stage = 'l'
stage = Stage.LOSS
loss = out.sum()
stage = 'b'
stage = Stage.BACKWARD
loss.backward()
fwd_flop = flop_count['f']
bwd_flop = flop_count['b']
fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
def unwrap(x):
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
return tree_map(unwrap, out), graph_info
def profile_function(target: 'Target') -> Callable:
@ -130,17 +173,15 @@ def profile_function(target: 'Target') -> Callable:
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
>>> output, meta_info = profile_function(func)(input, inplace=False)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
if kwargs.get('inplace', False):
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, (0, 0), (0, 0, 0, 0)
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
return out, flop_count, mem_stat
# If an argument indicates that this `call_function` is inplace, we should
# skip the autograd profiling.
out, meta = _profile(func, *args, **kwargs)
return out, meta
f.__name__ = target.__name__
func = target
@ -156,8 +197,8 @@ def profile_method(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
out, flop_count, mem_stat = _profile(target, *args, **kwargs)
return out, flop_count, mem_stat
out, meta = _profile(target, *args, inplace=False, **kwargs)
return out, meta
return f
@ -174,17 +215,15 @@ def profile_module(module: torch.nn.Module) -> Callable:
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
>>> output, meta_info = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
if getattr(module, 'inplace', False):
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, (out.numel(), out.numel()), (0, 0, 0, 0)
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
return out, flop_count, mem_stat
# If an argument indicates that this `call_module` is inplace, we should
# skip the autograd profiling.
out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
return out, meta
f.__name__ = module.__class__.__name__
func = module.forward
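
With this change, the profiler entry points return `(out, GraphInfo)` instead of the old `(out, flop_count, mem_stat)` tuples. A short sketch following the docstring examples above; the import path mirrors the `__init__.py` exports shown in this diff.

import torch
from colossalai.fx.profiler import profile_function, profile_module

inp = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3)

# Each wrapped call returns the output together with a GraphInfo record.
out, meta = profile_module(mod)(inp)
print(meta.fwd_flop, meta.bwd_flop)        # FLOPs counted per autograd stage
print(meta.fwd_mem_tmp, meta.bwd_mem_tmp)  # memory estimated by the dataflow analysis

out, meta = profile_function(torch.nn.functional.relu)(inp, inplace=False)
print(meta.fwd_mem_out, meta.bwd_mem_out)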