mirror of https://github.com/hpcaitech/ColossalAI
[fx] provide an accurate estimation of memory. (#1587)
* [fx] add some comment and docstrings. * [fx] add dataflow analysis for an autograd graph. * add interpretation for graph analysis. * [fx] before doing save_tensor_hooks. * [fx] provide an accurate estimation of memory except for GPT-2. * [fx] provide an accurate estimation of memory except for GPT-2. * [fx] provide an accurate estimation of memory except for GPT-2. * [fx] a very accurate version on GPT-2. * [fx] refactor code. * [fx] remove redundant inplace=True. * [fx] refactor code. * [fx] refactor code. * [fx] refactor code. * [fx] dive into backward memory. pull/1604/head
parent
27fe8af60c
commit
5c494d4540
|
@ -1,10 +1,12 @@
|
||||||
|
from dataclasses import asdict
|
||||||
|
from colossalai.fx.profiler import GraphInfo
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.fx.node import Node, Argument, Target
|
from torch.fx.node import Node, Argument, Target
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
from typing import Any, Tuple, NamedTuple, Dict
|
from typing import Any, Tuple, NamedTuple, Dict
|
||||||
from torch.fx._compatibility import compatibility
|
from torch.fx._compatibility import compatibility
|
||||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
|
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
|
@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
||||||
class MetaInfoProp(torch.fx.Interpreter):
|
class MetaInfoProp(torch.fx.Interpreter):
|
||||||
"""
|
"""
|
||||||
Execute an FX graph Node-by-Node with meta tensor and
|
Execute an FX graph Node-by-Node with meta tensor and
|
||||||
record the shape, FLOPs, MACs and type of the result
|
record the memory usage, FLOPs, and type of the result
|
||||||
into the corresponding node.
|
into the corresponding node.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
@ -82,7 +84,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
Returns:
|
Returns:
|
||||||
Any: The result of executing ``n``
|
Any: The result of executing ``n``
|
||||||
"""
|
"""
|
||||||
result, flop_count, mem_stat = super().run_node(n)
|
result, meta_info = super().run_node(n)
|
||||||
|
|
||||||
def extract_tensor_meta(obj):
|
def extract_tensor_meta(obj):
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
|
@ -90,21 +92,20 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
else:
|
else:
|
||||||
return TensorMetadata(None, None, False, None, 0, False)
|
return TensorMetadata(None, None, False, None, 0, False)
|
||||||
|
|
||||||
meta = tree_map(extract_tensor_meta, result)
|
tensor_meta = tree_map(extract_tensor_meta, result)
|
||||||
n.meta['tensor_meta'] = meta
|
n.meta['tensor_meta'] = tensor_meta
|
||||||
|
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||||
|
|
||||||
# TODO: the attribute node_size should be removed in the future
|
# TODO: the attribute node_size should be removed in the future
|
||||||
setattr(n, 'node_size', mem_stat[1])
|
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
|
||||||
setattr(n, 'fwd_flop', flop_count[0])
|
for par in n.all_input_nodes:
|
||||||
setattr(n, 'bwd_flop', flop_count[1])
|
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
|
||||||
setattr(n, 'fwd_tmp', mem_stat[0])
|
|
||||||
setattr(n, 'fwd_out', mem_stat[1])
|
|
||||||
setattr(n, 'bwd_tmp', mem_stat[2])
|
|
||||||
setattr(n, 'bwd_out', mem_stat[3])
|
|
||||||
n.meta['type'] = type(result)
|
n.meta['type'] = type(result)
|
||||||
|
|
||||||
|
# retain the autograd graph
|
||||||
for param in self.module.parameters():
|
for param in self.module.parameters():
|
||||||
param.grad = None
|
param.grad = None
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Main Node running APIs
|
# Main Node running APIs
|
||||||
|
@ -125,12 +126,9 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result (Any): The argument value that was retrieved
|
result (Any): The argument value that was retrieved
|
||||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
result = super().placeholder(target, args, kwargs)
|
return super().placeholder(target, args, kwargs), GraphInfo()
|
||||||
# A placeholder node only has activation
|
|
||||||
return result, (0, 0), (0, activation_size(result), 0, 0)
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
@ -147,10 +145,9 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
result (Any): The argument value that was retrieved
|
result (Any): The argument value that was retrieved
|
||||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
|
return super().get_attr(target, args, kwargs), GraphInfo()
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
@ -166,8 +163,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
|
|
||||||
Return
|
Return
|
||||||
result (Any): The argument value that was retrieved
|
result (Any): The argument value that was retrieved
|
||||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
assert not isinstance(target, str)
|
assert not isinstance(target, str)
|
||||||
return profile_function(target)(*args, **kwargs)
|
return profile_function(target)(*args, **kwargs)
|
||||||
|
@ -186,8 +182,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
|
|
||||||
Return
|
Return
|
||||||
result (Any): The argument value that was retrieved
|
result (Any): The argument value that was retrieved
|
||||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
return profile_method(target)(*args, **kwargs)
|
return profile_method(target)(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -205,8 +200,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
|
|
||||||
Return
|
Return
|
||||||
result (Any): The argument value that was retrieved
|
result (Any): The argument value that was retrieved
|
||||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
# Retrieve executed args and kwargs values from the environment
|
# Retrieve executed args and kwargs values from the environment
|
||||||
# Execute the method and return the result
|
# Execute the method and return the result
|
||||||
|
@ -229,10 +223,9 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
result (Any): The argument value that was retrieved
|
result (Any): The argument value that was retrieved
|
||||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
return args[0], (0, 0), (0, 0, 0, 0)
|
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
|
||||||
|
|
||||||
def propagate(self, *args):
|
def propagate(self, *args):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -2,8 +2,9 @@ from ... import META_COMPATIBILITY
|
||||||
if META_COMPATIBILITY:
|
if META_COMPATIBILITY:
|
||||||
from .opcount import flop_mapping
|
from .opcount import flop_mapping
|
||||||
from .tensor import MetaTensor
|
from .tensor import MetaTensor
|
||||||
from .profiler import profile_function, profile_method, profile_module, _profile
|
from .profiler import profile_function, profile_method, profile_module
|
||||||
else:
|
else:
|
||||||
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
|
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
|
||||||
|
|
||||||
|
from .dataflow import GraphInfo
|
||||||
from .memory import parameter_size, activation_size
|
from .memory import parameter_size, activation_size
|
||||||
|
|
|
@ -0,0 +1,136 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict
|
||||||
|
from torch.fx import Graph, Node
|
||||||
|
from .memory import activation_size
|
||||||
|
|
||||||
|
|
||||||
|
class Stage(Enum):
|
||||||
|
FORWARD = 0
|
||||||
|
LOSS = 1
|
||||||
|
BACKWARD = 2
|
||||||
|
PLACEHOLDER = 3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GraphInfo:
|
||||||
|
"""
|
||||||
|
GraphInfo is a dataclass for MetaInfo, which measures
|
||||||
|
the execution memory cost and FLOPs with `MetaTensor`.
|
||||||
|
The dataflow analysis is conducted on a single node of the FX graph.
|
||||||
|
============================================================================
|
||||||
|
-------------------------------
|
||||||
|
| Node |
|
||||||
|
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
|
||||||
|
placeholders saved for | | \__________ | |
|
||||||
|
backward. | | \ | |
|
||||||
|
| [fwd_tmp] ------> [bwd_tmp] | <-----
|
||||||
|
| | \_________ | | [bwd_tmp] marks the peak memory
|
||||||
|
| / \ \ | | in backward pass.
|
||||||
|
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
|
||||||
|
in [fwd_tmp] because | | | \_____ | |
|
||||||
|
it is not saved for | | | \ | |
|
||||||
|
backward. -------------------------------
|
||||||
|
============================================================================
|
||||||
|
Attributes:
|
||||||
|
fwd_flop (int): The forward FLOPs of a certain node
|
||||||
|
bwd_flop (int): The backward FLOPs of a certain node.
|
||||||
|
fwd_mem_in (int): See the above illustration.
|
||||||
|
fwd_mem_tmp (int): See the above illustration.
|
||||||
|
bwd_mem_tmp (int): See the above illustration.
|
||||||
|
bwd_mem_out (int): See the above illustration.
|
||||||
|
"""
|
||||||
|
fwd_flop: int = 0
|
||||||
|
bwd_flop: int = 0
|
||||||
|
fwd_mem_in: int = 0
|
||||||
|
fwd_mem_tmp: int = 0
|
||||||
|
bwd_mem_tmp: int = 0
|
||||||
|
bwd_mem_out: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_forward(n: Node):
|
||||||
|
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
|
||||||
|
return n.meta['stage'] == Stage.FORWARD
|
||||||
|
|
||||||
|
|
||||||
|
def is_loss(n: Node):
|
||||||
|
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
|
||||||
|
return n.meta['stage'] == Stage.LOSS
|
||||||
|
|
||||||
|
|
||||||
|
def is_placeholder(n: Node):
|
||||||
|
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
|
||||||
|
return n.meta['stage'] == Stage.PLACEHOLDER
|
||||||
|
|
||||||
|
|
||||||
|
def is_backward(n: Node):
|
||||||
|
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
|
||||||
|
return n.meta['stage'] == Stage.BACKWARD
|
||||||
|
|
||||||
|
|
||||||
|
def is_saved(n: Node):
|
||||||
|
return n.meta.get('saved', False)
|
||||||
|
|
||||||
|
|
||||||
|
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
|
||||||
|
"""Analyze the autograd node dependencies and find out the memory usage.
|
||||||
|
Basically the input graph should have all nodes marked with a `Stage` (`Stage.FORWARD`, `Stage.LOSS`, `Stage.BACKWARD` or `Stage.PLACEHOLDER`) under the meta key `stage`.
|
||||||
|
Each node should also carry the meta key `out`, which holds the output of that node.
|
||||||
|
============================================================================
|
||||||
|
Placeholder ----> p o <---- We need to keep track of grad out
|
||||||
|
|\________ |
|
||||||
|
↓ ↘|
|
||||||
|
f --------> b
|
||||||
|
|\ \_____ ↑
|
||||||
|
| \ ↘ /
|
||||||
|
f f ----> b <---- Not every forward result needs to be saved for backward
|
||||||
|
| \____ ↑
|
||||||
|
↘ ↘|
|
||||||
|
f ----> b <---- Backward can be freed as soon as it is no longer required.
|
||||||
|
↘ ↗
|
||||||
|
l
|
||||||
|
=============================================================================
|
||||||
|
Args:
|
||||||
|
graph (Graph): The autograd graph whose nodes are marked with a `Stage` (`Stage.FORWARD`, `Stage.LOSS`, `Stage.BACKWARD` or `Stage.PLACEHOLDER`) under the meta key `stage`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
graph_info (GraphInfo): Meta information for the dataflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _peak_memory(deps: Dict[Node, int]):
|
||||||
|
bwd_tmp = 0
|
||||||
|
for k, v in deps.items():
|
||||||
|
if v > 0:
|
||||||
|
bwd_tmp += activation_size(k.meta['out'])
|
||||||
|
return bwd_tmp
|
||||||
|
|
||||||
|
# deps is used to track all the memory dependencies of the graph.
|
||||||
|
deps = {}
|
||||||
|
graph_info = GraphInfo()
|
||||||
|
|
||||||
|
for n in graph.nodes:
|
||||||
|
n: Node
|
||||||
|
if is_saved(n) and not any(map(is_loss, n.users)):
|
||||||
|
# A forward tensor which is marked `saved` but is not
|
||||||
|
# an input to `loss` should be saved during forward.
|
||||||
|
# If the tensor is a placeholder, then it belongs to `fwd_in`.
|
||||||
|
# Any `fwd_in` should be kept in memory even if this function
|
||||||
|
# is checkpointed.
|
||||||
|
# Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
|
||||||
|
# the node, `fwd_tmp` can be freed.
|
||||||
|
if is_placeholder(n):
|
||||||
|
graph_info.fwd_mem_in += activation_size(n.meta['out'])
|
||||||
|
if is_forward(n):
|
||||||
|
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
|
||||||
|
elif is_backward(n):
|
||||||
|
if len(n.users):
|
||||||
|
# liveness analysis is only used in backward
|
||||||
|
deps[n] = len(n.users)
|
||||||
|
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
|
||||||
|
for input_n in n.all_input_nodes:
|
||||||
|
if input_n in deps:
|
||||||
|
deps[input_n] -= 1
|
||||||
|
else:
|
||||||
|
# basically a backward node without user is a `grad_out` node
|
||||||
|
graph_info.bwd_mem_out += activation_size(n.meta['out'])
|
||||||
|
return graph_info
|
|
@ -1,3 +1,4 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Any, Dict, Tuple
|
from typing import Callable, Any, Dict, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.node import Argument, Target
|
from torch.fx.node import Argument, Target
|
||||||
|
@ -6,6 +7,44 @@ from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLAC
|
||||||
|
|
||||||
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
||||||
|
|
||||||
|
|
||||||
|
# this is for compatibility use
|
||||||
|
@dataclass
|
||||||
|
class GraphInfo:
|
||||||
|
"""
|
||||||
|
GraphInfo is a dataclass for MetaInfo, which measures
|
||||||
|
the execution memory cost and FLOPs with `MetaTensor`.
|
||||||
|
The dataflow analysis is conducted on a single node of the FX graph.
|
||||||
|
============================================================================
|
||||||
|
-------------------------------
|
||||||
|
| Node |
|
||||||
|
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
|
||||||
|
placeholders saved for | | \__________ | |
|
||||||
|
backward. | | \ | |
|
||||||
|
| [fwd_tmp] ------> [bwd_tmp] | <-----
|
||||||
|
| | \_________ | | [bwd_tmp] marks the peak memory
|
||||||
|
| / \ \ | | in backward pass.
|
||||||
|
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
|
||||||
|
in [fwd_tmp] because | | | \_____ | |
|
||||||
|
it is not saved for | | | \ | |
|
||||||
|
backward. -------------------------------
|
||||||
|
============================================================================
|
||||||
|
Attributes:
|
||||||
|
fwd_flop (int): The forward FLOPs of a certain node
|
||||||
|
bwd_flop (int): The backward FLOPs of a certain node.
|
||||||
|
fwd_mem_in (int): See the above illustration.
|
||||||
|
fwd_mem_tmp (int): See the above illustration.
|
||||||
|
bwd_mem_tmp (int): See the above illustration.
|
||||||
|
bwd_mem_out (int): See the above illustration.
|
||||||
|
"""
|
||||||
|
fwd_flop: int = 0
|
||||||
|
bwd_flop: int = 0
|
||||||
|
fwd_mem_in: int = 0
|
||||||
|
fwd_mem_tmp: int = 0
|
||||||
|
bwd_mem_tmp: int = 0
|
||||||
|
bwd_mem_out: int = 0
|
||||||
|
|
||||||
|
|
||||||
CALL_FUNCTION_MSG = \
|
CALL_FUNCTION_MSG = \
|
||||||
"""
|
"""
|
||||||
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
|
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
|
||||||
|
@ -59,7 +98,7 @@ def profile_function(target: 'Target') -> Callable:
|
||||||
else:
|
else:
|
||||||
profiler = meta_profiler_function.get(target.__name__)
|
profiler = meta_profiler_function.get(target.__name__)
|
||||||
fwd_flop, _ = profiler(*args, **kwargs)
|
fwd_flop, _ = profiler(*args, **kwargs)
|
||||||
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
|
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
|
||||||
|
|
||||||
f.__name__ = target.__name__
|
f.__name__ = target.__name__
|
||||||
func = target
|
func = target
|
||||||
|
@ -88,7 +127,7 @@ def profile_method(target: 'Target') -> Callable:
|
||||||
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
|
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
|
||||||
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
|
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
|
||||||
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
|
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
|
||||||
return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
|
return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@ -118,7 +157,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
||||||
fwd_out = activation_size(out)
|
fwd_out = activation_size(out)
|
||||||
profiler = meta_profiler_module.get(type(module))
|
profiler = meta_profiler_module.get(type(module))
|
||||||
fwd_flop, _ = profiler(module, *args, **kwargs)
|
fwd_flop, _ = profiler(module, *args, **kwargs)
|
||||||
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
|
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
|
||||||
|
|
||||||
f.__name__ = module.__class__.__name__
|
f.__name__ = module.__class__.__name__
|
||||||
func = module.forward
|
func = module.forward
|
||||||
|
|
|
@ -14,12 +14,10 @@ if META_COMPATIBILITY:
|
||||||
|
|
||||||
INPLACE_ATEN = [
|
INPLACE_ATEN = [
|
||||||
aten.add_.Tensor,
|
aten.add_.Tensor,
|
||||||
aten.add.Tensor,
|
|
||||||
aten.sub_.Tensor,
|
aten.sub_.Tensor,
|
||||||
aten.div_.Tensor,
|
aten.div_.Tensor,
|
||||||
aten.div_.Scalar,
|
aten.div_.Scalar,
|
||||||
aten.mul_.Tensor,
|
aten.mul_.Tensor,
|
||||||
aten.mul.Tensor,
|
|
||||||
aten.bernoulli_.float,
|
aten.bernoulli_.float,
|
||||||
|
|
||||||
# inplace reshaping
|
# inplace reshaping
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import auto
|
||||||
from typing import Callable, Any, Dict, Tuple
|
from typing import Callable, Any, Dict, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import Graph
|
from torch.fx import Graph, Node
|
||||||
from torch.fx.node import Argument, Target
|
from torch.fx.node import Argument, Target
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
|
from .dataflow import autograd_graph_analysis, Stage
|
||||||
|
from .memory import WEIRD_OPS
|
||||||
from .tensor import MetaTensor
|
from .tensor import MetaTensor
|
||||||
from .opcount import flop_mapping
|
from .opcount import flop_mapping
|
||||||
|
|
||||||
__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
|
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
||||||
|
|
||||||
|
|
||||||
def normalize_tuple(x):
|
def normalize_tuple(x):
|
||||||
|
@ -20,8 +23,9 @@ def is_autogradable(x):
|
||||||
return isinstance(x, torch.Tensor) and x.is_floating_point()
|
return isinstance(x, torch.Tensor) and x.is_floating_point()
|
||||||
|
|
||||||
|
|
||||||
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
|
def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
|
||||||
"""Profile a Callable function with args and kwargs.
|
"""
|
||||||
|
Profile a Callable function with args and kwargs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target (Callable): A Callable function
|
target (Callable): A Callable function
|
||||||
|
@ -29,25 +33,32 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
|
||||||
kwargs (Any): Argument
|
kwargs (Any): Argument
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
out (Tuple[Any, ...]): The argument value that was retrieved
|
out (Tuple[Any, ...]): The argument value that was retrieved.
|
||||||
flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
|
meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||||
mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
|
||||||
"""
|
"""
|
||||||
|
# This subgraph traces aten level ops inside one node.
|
||||||
|
subgraph = Graph()
|
||||||
|
|
||||||
|
# `flop_count` serves as a global dictionary to store results.
|
||||||
flop_count = {
|
flop_count = {
|
||||||
'f': 0,
|
Stage.FORWARD: 0,
|
||||||
'l': 0,
|
Stage.LOSS: 0,
|
||||||
'b': 0,
|
Stage.BACKWARD: 0,
|
||||||
}
|
}
|
||||||
temp = {
|
|
||||||
'f': [],
|
|
||||||
'l': [],
|
|
||||||
'b': [],
|
|
||||||
}
|
|
||||||
stage = 'f'
|
|
||||||
|
|
||||||
|
# `stage` will mark the stage of autograd from outside scope.
|
||||||
|
stage = Stage.FORWARD
|
||||||
|
|
||||||
|
# FlopTensor not only gets the flop statistics of a single node,
|
||||||
|
# it also builds a full autograd graph for this node.
|
||||||
|
# This makes sure we can analyze the dependencies of memory, and
|
||||||
|
# decide which forward intermediate results should be kept until
|
||||||
|
# backward is executed.
|
||||||
|
# Hopefully, this attempt will provide a better estimation of memory.
|
||||||
class FlopTensor(MetaTensor):
|
class FlopTensor(MetaTensor):
|
||||||
|
|
||||||
|
_node: Node
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if self.grad_fn:
|
if self.grad_fn:
|
||||||
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
|
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
|
||||||
|
@ -56,42 +67,76 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||||
|
|
||||||
|
def get_node(x):
|
||||||
|
return None if not hasattr(x, '_node') else x._node
|
||||||
|
|
||||||
|
args_node = tree_map(get_node, args)
|
||||||
|
kwargs_node = tree_map(get_node, kwargs)
|
||||||
|
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
|
||||||
|
|
||||||
def unwrap(x):
|
def unwrap(x):
|
||||||
|
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
|
||||||
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
||||||
x = FlopTensor(x.to('meta'))
|
x = FlopTensor(x.to('meta'))
|
||||||
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
|
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
|
||||||
|
|
||||||
def to_meta(x):
|
|
||||||
return x.to('meta') if isinstance(x, torch.Tensor) else x
|
|
||||||
|
|
||||||
args = tree_map(unwrap, args)
|
args = tree_map(unwrap, args)
|
||||||
kwargs = tree_map(unwrap, kwargs)
|
kwargs = tree_map(unwrap, kwargs)
|
||||||
|
|
||||||
# run aten for backend=CPU but actually on backend=Meta
|
# run aten for backend=CPU but actually on backend=Meta
|
||||||
out = func(*args, **kwargs)
|
out = func(*args, **kwargs)
|
||||||
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
|
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
|
||||||
if func not in INPLACE_ATEN:
|
node.meta['out'] = normalize_tuple(out)
|
||||||
temp[stage].append(tree_map(to_meta, normalize_tuple(out)))
|
node.meta['stage'] = stage
|
||||||
|
|
||||||
def wrap(x):
|
def wrap(x):
|
||||||
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
|
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
|
||||||
|
|
||||||
return tree_map(wrap, out)
|
def set_node(x):
|
||||||
|
x._node = node
|
||||||
|
|
||||||
|
out = tree_map(wrap, out)
|
||||||
|
tree_map(set_node, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# `WEIRD_OPS` are tough to handle because they don't accept autograd
|
||||||
|
# on meta tensor.
|
||||||
if target not in WEIRD_OPS:
|
if target not in WEIRD_OPS:
|
||||||
|
|
||||||
def wrap(x):
|
def wrap(x):
|
||||||
return FlopTensor(
|
return FlopTensor(x.detach().requires_grad_(
|
||||||
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
|
True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def wrap(x):
|
def wrap(x):
|
||||||
return FlopTensor(
|
return FlopTensor(x.detach().requires_grad_(
|
||||||
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
|
False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
|
||||||
|
|
||||||
|
# Basically, we need to detach the args and kwargs from the outer graph.
|
||||||
args = tree_map(wrap, args)
|
args = tree_map(wrap, args)
|
||||||
kwargs = tree_map(wrap, kwargs)
|
kwargs = tree_map(wrap, kwargs)
|
||||||
|
|
||||||
|
def set_placeholder(x):
|
||||||
|
if isinstance(x, FlopTensor):
|
||||||
|
x._node = subgraph.create_node('placeholder',
|
||||||
|
'placeholder', (subgraph._root,),
|
||||||
|
name=subgraph._graph_namespace.create_name('input', x._tensor))
|
||||||
|
x._node.meta['stage'] = Stage.PLACEHOLDER
|
||||||
|
x._node.meta['out'] = (x._tensor,)
|
||||||
|
|
||||||
|
tree_map(set_placeholder, args)
|
||||||
|
tree_map(set_placeholder, kwargs)
|
||||||
|
|
||||||
|
def pack(x):
|
||||||
|
if isinstance(x, FlopTensor):
|
||||||
|
x._node.meta['saved'] = True
|
||||||
|
return x
|
||||||
|
|
||||||
|
def unpack(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
# mark saved tensors with saved_tensors_hooks
|
||||||
|
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
|
||||||
if isinstance(target, str):
|
if isinstance(target, str):
|
||||||
# args[0] is the `self` object for this method call
|
# args[0] is the `self` object for this method call
|
||||||
self_obj, *args_tail = args
|
self_obj, *args_tail = args
|
||||||
|
@ -99,23 +144,21 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
|
||||||
else:
|
else:
|
||||||
out = target(*args, **kwargs)
|
out = target(*args, **kwargs)
|
||||||
|
|
||||||
|
# If the output is not a floating point `torch.Tensor` or it does not
|
||||||
|
# requires grad, then we should not run backward for this node.
|
||||||
if is_autogradable(out) and out.requires_grad:
|
if is_autogradable(out) and out.requires_grad:
|
||||||
stage = 'l'
|
stage = Stage.LOSS
|
||||||
loss = out.sum()
|
loss = out.sum()
|
||||||
stage = 'b'
|
stage = Stage.BACKWARD
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
fwd_flop = flop_count['f']
|
graph_info = autograd_graph_analysis(subgraph)
|
||||||
bwd_flop = flop_count['b']
|
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
|
||||||
|
|
||||||
fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
|
|
||||||
fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
|
|
||||||
bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0
|
|
||||||
|
|
||||||
def unwrap(x):
|
def unwrap(x):
|
||||||
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
|
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
|
||||||
|
|
||||||
return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
|
return tree_map(unwrap, out), graph_info
|
||||||
|
|
||||||
|
|
||||||
def profile_function(target: 'Target') -> Callable:
|
def profile_function(target: 'Target') -> Callable:
|
||||||
|
@ -130,17 +173,15 @@ def profile_function(target: 'Target') -> Callable:
|
||||||
Examples:
|
Examples:
|
||||||
>>> input = torch.rand(100, 100, 100, 100, device='meta')
|
>>> input = torch.rand(100, 100, 100, 100, device='meta')
|
||||||
>>> func = torch.nn.functional.relu
|
>>> func = torch.nn.functional.relu
|
||||||
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
|
>>> output, meta_info = profile_function(func)(input, inplace=False)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||||
if kwargs.get('inplace', False):
|
|
||||||
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
|
# If there is an argument that this `call_function` is inplace, we should
|
||||||
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
|
# skip the autograd profiling.
|
||||||
out = func(*args, **kwargs)
|
out, meta = _profile(func, *args, **kwargs)
|
||||||
return out, (0, 0), (0, 0, 0, 0)
|
return out, meta
|
||||||
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
|
|
||||||
return out, flop_count, mem_stat
|
|
||||||
|
|
||||||
f.__name__ = target.__name__
|
f.__name__ = target.__name__
|
||||||
func = target
|
func = target
|
||||||
|
@ -156,8 +197,8 @@ def profile_method(target: 'Target') -> Callable:
|
||||||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||||
# execute the method and return the result
|
# execute the method and return the result
|
||||||
assert isinstance(target, str), f'{target} instance is not str.'
|
assert isinstance(target, str), f'{target} instance is not str.'
|
||||||
out, flop_count, mem_stat = _profile(target, *args, **kwargs)
|
out, meta = _profile(target, *args, inplace=False, **kwargs)
|
||||||
return out, flop_count, mem_stat
|
return out, meta
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@ -174,17 +215,15 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
||||||
Example:
|
Example:
|
||||||
>>> input = torch.rand(4, 3, 224, 224, device='meta')
|
>>> input = torch.rand(4, 3, 224, 224, device='meta')
|
||||||
>>> mod = torch.nn.Conv2d(3, 128, 3)
|
>>> mod = torch.nn.Conv2d(3, 128, 3)
|
||||||
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
|
>>> output, meta_info = profile_module(mod)(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||||
if getattr(module, 'inplace', False):
|
|
||||||
args = tree_map(lambda x: x.to('meta'), args)
|
# If there is an argument that this `call_module` is inplace, we should
|
||||||
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
|
# skip the autograd profiling.
|
||||||
out = func(*args, **kwargs)
|
out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
|
||||||
return out, (out.numel(), out.numel()), (0, 0, 0, 0)
|
return out, meta
|
||||||
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
|
|
||||||
return out, flop_count, mem_stat
|
|
||||||
|
|
||||||
f.__name__ = module.__class__.__name__
|
f.__name__ = module.__class__.__name__
|
||||||
func = module.forward
|
func = module.forward
|
||||||
|
|
Loading…
Reference in New Issue