mirror of https://github.com/hpcaitech/ColossalAI

[fx] provide an accurate estimation of memory. (#1587)

* [fx] add some comment and docstrings.
* [fx] add dataflow analysis for an autograd graph.
* add interpretation for graph analysis.
* [fx] before doing save_tensor_hooks.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] a very accurate version on GPT-2.
* [fx] refactor code.
* [fx] remove redundant inplace=True.
* [fx] refactor code.
* [fx] refactor code.
* [fx] refactor code.
* [fx] dive into backward memory.

pull/1604/head
parent 27fe8af60c
commit 5c494d4540

@@ -1,10 +1,12 @@
+from dataclasses import asdict
+from colossalai.fx.profiler import GraphInfo
 import torch
 import torch.fx
 from torch.fx.node import Node, Argument, Target
 from torch.utils._pytree import tree_map
 from typing import Any, Tuple, NamedTuple, Dict
 from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
+from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size


 @compatibility(is_backward_compatible=True)

@@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
 class MetaInfoProp(torch.fx.Interpreter):
     """
     Execute an FX graph Node-by-Node with meta tensor and
-    record the shape, FLOPs, MACs and type of the result
+    record the memory usage, FLOPs, and type of the result
     into the corresponding node.

     Usage:

@@ -82,7 +84,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         Returns:
             Any: The result of executing ``n``
         """
-        result, flop_count, mem_stat = super().run_node(n)
+        result, meta_info = super().run_node(n)

         def extract_tensor_meta(obj):
             if isinstance(obj, torch.Tensor):

@@ -90,21 +92,20 @@ class MetaInfoProp(torch.fx.Interpreter):
             else:
                 return TensorMetadata(None, None, False, None, 0, False)

-        meta = tree_map(extract_tensor_meta, result)
-        n.meta['tensor_meta'] = meta
+        tensor_meta = tree_map(extract_tensor_meta, result)
+        n.meta['tensor_meta'] = tensor_meta
+        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`

         # TODO: the attribute node_size should be removed in the future
-        setattr(n, 'node_size', mem_stat[1])
-        setattr(n, 'fwd_flop', flop_count[0])
-        setattr(n, 'bwd_flop', flop_count[1])
-        setattr(n, 'fwd_tmp', mem_stat[0])
-        setattr(n, 'fwd_out', mem_stat[1])
-        setattr(n, 'bwd_tmp', mem_stat[2])
-        setattr(n, 'bwd_out', mem_stat[3])
+        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
+        for par in n.all_input_nodes:
+            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
         n.meta['type'] = type(result)

+        # retain the autograd graph
+        for param in self.module.parameters():
+            param.grad = None
+
         return result

     # Main Node running APIs

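As context for the `asdict` merge above, a hedged sketch (not part of the diff) of what lands in `n.meta`, using the `GraphInfo` dataclass introduced later in this commit:

    from dataclasses import asdict

    info = GraphInfo(fwd_flop=10, bwd_flop=20, fwd_mem_tmp=512)
    meta = {'tensor_meta': None}
    meta = {**meta, **asdict(info)}
    # meta == {'tensor_meta': None, 'fwd_flop': 10, 'bwd_flop': 20,
    #          'fwd_mem_in': 0, 'fwd_mem_tmp': 512, 'bwd_mem_tmp': 0, 'bwd_mem_out': 0}
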
@@ -125,12 +126,9 @@ class MetaInfoProp(torch.fx.Interpreter):

         Returns:
             result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        result = super().placeholder(target, args, kwargs)
-        # A placeholder node only has activation
-        return result, (0, 0), (0, activation_size(result), 0, 0)
+        return super().placeholder(target, args, kwargs), GraphInfo()

     @compatibility(is_backward_compatible=True)
     def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:

@@ -147,10 +145,9 @@ class MetaInfoProp(torch.fx.Interpreter):

         Return:
             result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
+        return super().get_attr(target, args, kwargs), GraphInfo()

     @compatibility(is_backward_compatible=True)
     def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:

@@ -166,8 +163,7 @@ class MetaInfoProp(torch.fx.Interpreter):

         Return
             result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
         assert not isinstance(target, str)
         return profile_function(target)(*args, **kwargs)

@@ -186,8 +182,7 @@ class MetaInfoProp(torch.fx.Interpreter):

         Return
             result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
         return profile_method(target)(*args, **kwargs)

@@ -205,8 +200,7 @@ class MetaInfoProp(torch.fx.Interpreter):

         Return
             result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
         # Retrieve executed args and kwargs values from the environment
         # Execute the method and return the result

@@ -229,10 +223,9 @@ class MetaInfoProp(torch.fx.Interpreter):

         Return:
             result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        return args[0], (0, 0), (0, 0, 0, 0)
+        return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))

     def propagate(self, *args):
         """

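A hedged end-to-end sketch of the reworked pass (the plain `symbolic_trace` call and the toy `Linear` module are illustrative stand-ins, assuming the pass lives at `colossalai.fx.passes.meta_info_prop`; real models may need ColossalAI's own tracer):

    import torch
    import torch.fx
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp

    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    MetaInfoProp(gm).propagate(torch.rand(2, 4, device='meta'))

    for node in gm.graph.nodes:
        # per-node costs merged into node.meta by run_node above
        print(node.name, node.meta.get('fwd_flop'), node.meta.get('fwd_mem_out'))
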
@@ -2,8 +2,9 @@ from ... import META_COMPATIBILITY
 if META_COMPATIBILITY:
     from .opcount import flop_mapping
     from .tensor import MetaTensor
-    from .profiler import profile_function, profile_method, profile_module, _profile
+    from .profiler import profile_function, profile_method, profile_module
 else:
     from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module

+from .dataflow import GraphInfo
 from .memory import parameter_size, activation_size

@@ -0,0 +1,136 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict
+from torch.fx import Graph, Node
+from .memory import activation_size
+
+
+class Stage(Enum):
+    FORWARD = 0
+    LOSS = 1
+    BACKWARD = 2
+    PLACEHOLDER = 3
+
+
+@dataclass
+class GraphInfo:
+    """
+    GraphInfo is a dataclass for MetaInfo, which measures
+    the execution memory cost and FLOPs with `MetaTensor`.
+    The dataflow analysis is conducted on a single node of the FX graph.
+    ============================================================================
+                            -------------------------------
+                            |            Node             |
+    [fwd_in] are    ---->   |  [fwd_in]       [bwd_out]   |  <----- [bwd_out] marks the memory for `grad_out`
+    placeholders saved for  |     | \__________     |     |
+    backward.               |     |            \    |     |
+                            | [fwd_tmp] ------> [bwd_tmp] |  <-----
+                            |     |  \_________     |     |  [bwd_tmp] marks the peak memory
+                            |     /            \    |     |  in backward pass.
+    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |  <-----
+    in [fwd_tmp] because    |  |         |  \_____   |    |
+    it is not saved for     |  |         |        \  |    |
+    backward.               -------------------------------
+    ============================================================================
+    Attributes:
+        fwd_flop (int): The forward FLOPs of a certain node.
+        bwd_flop (int): The backward FLOPs of a certain node.
+        fwd_mem_in (int): See the above illustration.
+        fwd_mem_tmp (int): See the above illustration.
+        bwd_mem_tmp (int): See the above illustration.
+        bwd_mem_out (int): See the above illustration.
+    """
+    fwd_flop: int = 0
+    bwd_flop: int = 0
+    fwd_mem_in: int = 0
+    fwd_mem_tmp: int = 0
+    bwd_mem_tmp: int = 0
+    bwd_mem_out: int = 0
+
+
+def is_forward(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.FORWARD
+
+
+def is_loss(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.LOSS
+
+
+def is_placeholder(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.PLACEHOLDER
+
+
+def is_backward(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.BACKWARD
+
+
+def is_saved(n: Node):
+    return n.meta.get('saved', False)
+
+
+def autograd_graph_analysis(graph: Graph) -> GraphInfo:
+    """Analyze the autograd node dependencies and find out the memory usage.
+    Basically the input graph should have all nodes marked with `Stage.FORWARD`,
+    `Stage.LOSS` or `Stage.BACKWARD` under the keyword `stage`.
+    Nodes should have an attribute `out` indicating the output of each node.
+    ============================================================================
+    Placeholder ---->   p           o     <---- We need to keep track of grad out
+                        |\________  |
+                        ↓         ↘|
+                        f --------> b
+                        |\ \_____   ↑
+                        | \      ↘ /
+                        f  f ----> b   <---- Not every forward result needs to be saved for backward
+                        |   \____   ↑
+                        ↘        ↘ |
+                          f ----> b    <---- Backward can be freed as soon as it is required no more.
+                            ↘    ↗
+                              l
+    =============================================================================
+    Args:
+        graph (Graph): The autograd graph with nodes marked for the keyword `stage`.
+
+    Returns:
+        graph_info (GraphInfo): Meta information for the dataflow.
+    """
+
+    def _peak_memory(deps: Dict[Node, int]):
+        bwd_tmp = 0
+        for k, v in deps.items():
+            if v > 0:
+                bwd_tmp += activation_size(k.meta['out'])
+        return bwd_tmp
+
+    # deps is used to track all the memory dependencies of the graph.
+    deps = {}
+    graph_info = GraphInfo()
+
+    for n in graph.nodes:
+        n: Node
+        if is_saved(n) and not any(map(is_loss, n.users)):
+            # A forward tensor which is marked `saved` but is not
+            # an input to `loss` should be saved during forward.
+            # If the tensor is a placeholder, then it belongs to `fwd_in`.
+            # Any `fwd_in` should be kept in memory even if this function
+            # is checkpointed.
+            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
+            # the node, `fwd_tmp` can be freed.
+            if is_placeholder(n):
+                graph_info.fwd_mem_in += activation_size(n.meta['out'])
+            if is_forward(n):
+                graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+        elif is_backward(n):
+            if len(n.users):
+                # liveness analysis is only used in backward
+                deps[n] = len(n.users)
+                graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
+                for input_n in n.all_input_nodes:
+                    if input_n in deps:
+                        deps[input_n] -= 1
+            else:
+                # basically a backward node without users is a `grad_out` node
+                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+    return graph_info

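To make the expected node metadata concrete, here is a hypothetical toy graph run through the analysis (not from the commit; shapes and ops are arbitrary, and it assumes `activation_size` sums bytes over the `out` tuple):

    import torch
    from torch.fx import Graph
    # from colossalai.fx.profiler.dataflow import autograd_graph_analysis, Stage

    g = Graph()
    x = g.placeholder('x')
    x.meta = {'stage': Stage.PLACEHOLDER, 'saved': True,
              'out': (torch.empty(4, 4, device='meta'),)}
    y = g.call_function(torch.relu, (x,))
    y.meta = {'stage': Stage.FORWARD, 'saved': True,
              'out': (torch.empty(4, 4, device='meta'),)}
    gy = g.call_function(torch.ones_like, (y,))    # stands in for a grad node
    gy.meta = {'stage': Stage.BACKWARD,
               'out': (torch.empty(4, 4, device='meta'),)}

    info = autograd_graph_analysis(g)
    # info.fwd_mem_in  == 64 bytes: the saved placeholder `x` (4*4 fp32)
    # info.fwd_mem_tmp == 64 bytes: the saved forward tensor `y`
    # info.bwd_mem_out == 64 bytes: `gy` has no users, so it is a `grad_out`
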
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx.node import Argument, Target

@@ -6,6 +7,44 @@ from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLAC

 __all__ = ['profile_function', 'profile_module', 'profile_method']


+# this is for compatibility use
+@dataclass
+class GraphInfo:
+    """
+    GraphInfo is a dataclass for MetaInfo, which measures
+    the execution memory cost and FLOPs with `MetaTensor`.
+    The dataflow analysis is conducted on a single node of the FX graph.
+    ============================================================================
+                            -------------------------------
+                            |            Node             |
+    [fwd_in] are    ---->   |  [fwd_in]       [bwd_out]   |  <----- [bwd_out] marks the memory for `grad_out`
+    placeholders saved for  |     | \__________     |     |
+    backward.               |     |            \    |     |
+                            | [fwd_tmp] ------> [bwd_tmp] |  <-----
+                            |     |  \_________     |     |  [bwd_tmp] marks the peak memory
+                            |     /            \    |     |  in backward pass.
+    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |  <-----
+    in [fwd_tmp] because    |  |         |  \_____   |    |
+    it is not saved for     |  |         |        \  |    |
+    backward.               -------------------------------
+    ============================================================================
+    Attributes:
+        fwd_flop (int): The forward FLOPs of a certain node.
+        bwd_flop (int): The backward FLOPs of a certain node.
+        fwd_mem_in (int): See the above illustration.
+        fwd_mem_tmp (int): See the above illustration.
+        bwd_mem_tmp (int): See the above illustration.
+        bwd_mem_out (int): See the above illustration.
+    """
+    fwd_flop: int = 0
+    bwd_flop: int = 0
+    fwd_mem_in: int = 0
+    fwd_mem_tmp: int = 0
+    bwd_mem_tmp: int = 0
+    bwd_mem_out: int = 0
+
+
 CALL_FUNCTION_MSG = \
 """
 Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n

@@ -59,7 +98,7 @@ def profile_function(target: 'Target') -> Callable:
         else:
             profiler = meta_profiler_function.get(target.__name__)
             fwd_flop, _ = profiler(*args, **kwargs)
-        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
+        return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

     f.__name__ = target.__name__
     func = target

@@ -88,7 +127,7 @@ def profile_method(target: 'Target') -> Callable:
         # call_method has no parameters, is MOSTLY(?) inplace, and has no FLOPs or MACs.
         fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
         fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
-        return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
+        return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

     return f

@@ -118,7 +157,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
         fwd_out = activation_size(out)
         profiler = meta_profiler_module.get(type(module))
         fwd_flop, _ = profiler(module, *args, **kwargs)
-        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
+        return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

     f.__name__ = module.__class__.__name__
     func = module.forward

@@ -14,12 +14,10 @@ if META_COMPATIBILITY:

     INPLACE_ATEN = [
         aten.add_.Tensor,
-        aten.add.Tensor,
         aten.sub_.Tensor,
         aten.div_.Tensor,
         aten.div_.Scalar,
         aten.mul_.Tensor,
-        aten.mul.Tensor,
         aten.bernoulli_.float,

         # inplace reshaping

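The two removed entries are out-of-place overloads that allocate fresh outputs, so listing them as inplace made the profiler skip counting their output memory. A quick hedged illustration of the distinction via torch.ops:

    import torch
    aten = torch.ops.aten

    t = torch.rand(3)
    other = torch.ones(3)
    aten.add_.Tensor(t, other)        # inplace: mutates `t`, no new allocation
    out = aten.add.Tensor(t, other)   # out-of-place: allocates a new tensor
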
@@ -1,13 +1,16 @@
+from dataclasses import dataclass
+from enum import auto
 from typing import Callable, Any, Dict, Tuple
 import torch
-from torch.fx import Graph
+from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
+from .dataflow import autograd_graph_analysis, Stage
+from .memory import WEIRD_OPS
 from .tensor import MetaTensor
 from .opcount import flop_mapping

-__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
+__all__ = ['profile_function', 'profile_module', 'profile_method']


 def normalize_tuple(x):

@@ -20,8 +23,9 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
-    """Profile a Callable function with args and kwargs.
+def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
+    """
+    Profile a Callable function with args and kwargs.

     Args:
         target (Callable): A Callable function

@@ -29,25 +33,32 @@
         kwargs (Any): Argument

     Returns:
-        out (Tuple[Any, ...]): The argument value that was retrieved
-        flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
-        mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+        out (Tuple[Any, ...]): The argument value that was retrieved.
+        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
     """
+    # This subgraph traces aten level ops inside one node.
+    subgraph = Graph()
+
+    # `flop_count` serves as a global dictionary to store results.
     flop_count = {
-        'f': 0,
-        'l': 0,
-        'b': 0,
+        Stage.FORWARD: 0,
+        Stage.LOSS: 0,
+        Stage.BACKWARD: 0,
     }
-    temp = {
-        'f': [],
-        'l': [],
-        'b': [],
-    }
-    stage = 'f'

+    # `stage` will mark the stage of autograd from the outside scope.
+    stage = Stage.FORWARD
+
+    # FlopTensor not only gets the flop statistics of a single node;
+    # it also builds a full autograd graph for this node.
+    # This makes sure we can analyze the dependencies of memory, and
+    # decide which forward intermediate results should be kept until
+    # backward is executed.
+    # Hopefully, this attempt will provide a better estimation of memory.
     class FlopTensor(MetaTensor):

+        _node: Node
+
         def __repr__(self):
             if self.grad_fn:
                 return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"

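For background, a minimal self-contained sketch of dispatch-level op counting, using `TorchDispatchMode` (torch >= 1.12) as a simpler stand-in for the `FlopTensor` subclass; this is illustration, not the commit's code:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class OpCounter(TorchDispatchMode):

        def __init__(self):
            super().__init__()
            self.counts = {}

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # every intercepted aten op lands here
            self.counts[str(func)] = self.counts.get(str(func), 0) + 1
            return func(*args, **(kwargs or {}))

    with OpCounter() as counter:
        x = torch.rand(4, 4)
        y = (x @ x).relu()
    # counter.counts maps each intercepted aten op to an invocation count;
    # FlopTensor does the analogous interception from a tensor subclass and,
    # in addition, records every op as a node of the small `subgraph` above.
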
@@ -56,66 +67,98 @@
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

+            def get_node(x):
+                return None if not hasattr(x, '_node') else x._node
+
+            args_node = tree_map(get_node, args)
+            kwargs_node = tree_map(get_node, kwargs)
+            node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+
             def unwrap(x):
+                # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
+                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
+                    x = FlopTensor(x.to('meta'))
                 return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

-            def to_meta(x):
-                return x.to('meta') if isinstance(x, torch.Tensor) else x
-
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)

             # run aten for backend=CPU but actually on backend=Meta
             out = func(*args, **kwargs)
             flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
-            if func not in INPLACE_ATEN:
-                temp[stage].append(tree_map(to_meta, normalize_tuple(out)))
+            node.meta['out'] = normalize_tuple(out)
+            node.meta['stage'] = stage

             def wrap(x):
                 return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x

-            return tree_map(wrap, out)
+            def set_node(x):
+                x._node = node
+
+            out = tree_map(wrap, out)
+            tree_map(set_node, out)
+            return out

     # `WEIRD_OPS` are tough to handle because they don't accept autograd
     # on meta tensor.
     if target not in WEIRD_OPS:

         def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+            return FlopTensor(x.detach().requires_grad_(
+                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
     else:

         def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+            return FlopTensor(x.detach().requires_grad_(
+                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x

+    # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)
     kwargs = tree_map(wrap, kwargs)

-    if isinstance(target, str):
-        # args[0] is the `self` object for this method call
-        self_obj, *args_tail = args
-        out = getattr(self_obj, target)(*args_tail, **kwargs)
-    else:
-        out = target(*args, **kwargs)
+    def set_placeholder(x):
+        if isinstance(x, FlopTensor):
+            x._node = subgraph.create_node('placeholder',
+                                           'placeholder', (subgraph._root,),
+                                           name=subgraph._graph_namespace.create_name('input', x._tensor))
+            x._node.meta['stage'] = Stage.PLACEHOLDER
+            x._node.meta['out'] = (x._tensor,)
+
+    tree_map(set_placeholder, args)
+    tree_map(set_placeholder, kwargs)
+
+    def pack(x):
+        if isinstance(x, FlopTensor):
+            x._node.meta['saved'] = True
+        return x
+
+    def unpack(x):
+        return x
+
+    # mark saved tensors with saved_tensors_hooks
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        if isinstance(target, str):
+            # args[0] is the `self` object for this method call
+            self_obj, *args_tail = args
+            out = getattr(self_obj, target)(*args_tail, **kwargs)
+        else:
+            out = target(*args, **kwargs)

     # If the output is not a floating point `torch.Tensor` or it does not
     # require grad, then we should not run backward for this node.
     if is_autogradable(out) and out.requires_grad:
-        stage = 'l'
+        stage = Stage.LOSS
         loss = out.sum()
-        stage = 'b'
+        stage = Stage.BACKWARD
         loss.backward()

-    fwd_flop = flop_count['f']
-    bwd_flop = flop_count['b']
-
-    fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
-    fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
-    bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0
+    graph_info = autograd_graph_analysis(subgraph)
+    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]

     def unwrap(x):
         return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

-    return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
+    return tree_map(unwrap, out), graph_info


 def profile_function(target: 'Target') -> Callable:

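The `pack`/`unpack` pair above leans on a stock PyTorch facility; a hedged standalone illustration with plain tensors (no FlopTensor):

    import torch

    saved_shapes = []

    def pack(x):
        saved_shapes.append(tuple(x.shape))   # observe what autograd saves
        return x                              # returned value is what unpack receives

    def unpack(x):
        return x

    a = torch.rand(4, 4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = (a * a).sum()
    out.backward()
    # saved_shapes lists the tensors autograd kept for backward; in `_profile`
    # the same hook instead flags the tensor's subgraph node with meta['saved']
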
@@ -130,17 +173,15 @@ def profile_function(target: 'Target') -> Callable:
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
-        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
+        >>> output, meta_info = profile_function(func)(input, inplace=False)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        if kwargs.get('inplace', False):
-            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
-            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
-            out = func(*args, **kwargs)
-            return out, (0, 0), (0, 0, 0, 0)
-        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
-        return out, flop_count, mem_stat
+        # If there is an argument that this `call_function` is inplace, we should
+        # skip the autograd profiling.
+        out, meta = _profile(func, *args, **kwargs)
+        return out, meta

     f.__name__ = target.__name__
     func = target

@@ -156,8 +197,8 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
        assert isinstance(target, str), f'{target} instance is not str.'
-        out, flop_count, mem_stat = _profile(target, *args, **kwargs)
-        return out, flop_count, mem_stat
+        out, meta = _profile(target, *args, inplace=False, **kwargs)
+        return out, meta

     return f

@@ -174,17 +215,15 @@ def profile_module(module: torch.nn.Module) -> Callable:
     Example:
         >>> input = torch.rand(4, 3, 224, 224, device='meta')
         >>> mod = torch.nn.Conv2d(3, 128, 3)
-        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
+        >>> output, meta_info = profile_module(mod)(input)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        if getattr(module, 'inplace', False):
-            args = tree_map(lambda x: x.to('meta'), args)
-            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
-            out = func(*args, **kwargs)
-            return out, (out.numel(), out.numel()), (0, 0, 0, 0)
-        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
-        return out, flop_count, mem_stat
+        # If there is an argument that this `call_module` is inplace, we should
+        # skip the autograd profiling.
+        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
+        return out, meta

     f.__name__ = module.__class__.__name__
     func = module.forward

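On the `inplace` forwarding: modules such as `nn.ReLU(inplace=True)` expose an `inplace` attribute, which `_profile` uses to skip wrapping inputs in autograd-tracking `FlopTensor`s, so no backward pass is simulated for them. A hedged sketch of the expected call shape:

    import torch
    # from colossalai.fx.profiler import profile_module

    mod = torch.nn.ReLU(inplace=True)
    x = torch.rand(2, 3, device='meta')
    out, meta = profile_module(mod)(x)
    # with inplace=True the wrap() above never sets requires_grad, so the
    # loss/backward stages are skipped and `meta` keeps its zero defaults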