[fx] Add concrete info prop (#1677)

* [fx] concreteinfoprop

* [fx] add concreteinfoprop

* [fx] modify docstring of ConcreteInfoProp

* [fx] fix device error

* [fx] modify parameter calculation

* [fx] modify parameters calculation
Boyuan Yao 2 years ago committed by GitHub
parent 1df98d5b66
commit 132b4306b7

@@ -1,3 +1,4 @@
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
from .meta_info_prop import MetaInfoProp
from .concrete_info_prop import ConcreteInfoProp

@@ -0,0 +1,290 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_flatten
from typing import Any, List, Tuple, NamedTuple, Dict, Optional
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from torch.fx.graph_module import GraphModule


@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node with concrete tensor and record the memory
usage, execution time of forward and backward, and type of the result into
the corresponding node.
Usage:
BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
gm = symbolic_trace(model)
interp = ConcreteInfoProp(gm)
interp.run(input_sample)
print(interp.summary(unit='kb'))
output of above code is
Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- ---------
placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 0.0003993511199951172 s 0.00706791877746582 s False 0.50 KB 0.00 KB 0.03 KB 0.66 KB
call_module _1 6.29425048828125e-05 s 0.00018286705017089844 s False 0.50 KB 0.00 KB 0.12 KB 0.81 KB
output output 0.0 s 0.0 s True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed
"""
_is_proped: bool = False
    def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
        """Customized run for ConcreteInfoProp.

        We need to store the device of the sample input in self.device so that every
        node can be profiled on the same device.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter.
            enable_io_processing (bool): If True, the inputs and outputs are first processed with the
                graph's process_inputs and process_outputs functions before being used.

        Returns:
            Any: The value returned from executing the Module
        """
        flatten_args, _ = tree_flatten(args)
        # infer the target device from the first argument that carries one
        # (assumes at least one input is a tensor-like object)
        self.device = next(item for item in flatten_args if hasattr(item, "device")).device
        return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)

    @compatibility(is_backward_compatible=True)
    def run_node(self, n: Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        self._is_proped = True
        result, meta_info = super().run_node(n)

        n.meta = {**n.meta, **asdict(meta_info)}    # merge the profiled GraphInfo into `n.meta`
        # TODO: the attribute node_size should be removed in the future
        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
        n.meta['type'] = type(result)

        # release the parameter gradients accumulated by the profiling backward pass
        for param in self.module.parameters():
            param.grad = None

        return result
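
    # For illustration (hypothetical variable names): after `ConcreteInfoProp(gm).run(x)`,
    # each node carries the profiled record merged from `GraphInfo`, e.g.:
    #   node.meta['fwd_time'], node.meta['bwd_time']        # seconds
    #   node.meta['fwd_mem_out'], node.meta['bwd_mem_tmp']  # bytes
    #   node.meta['save_fwd_in'], node.meta['type']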

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            result (Any): The argument value that was retrieved
            meta_info (GraphInfo): The memory cost and forward & backward time.
        """
        return super().placeholder(target, args, kwargs), GraphInfo()

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            result (Any): The attribute value that was retrieved
            meta_info (GraphInfo): The memory cost and forward & backward time.
        """
        return super().get_attr(target, args, kwargs), GraphInfo()

    @compatibility(is_backward_compatible=True)
    def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node with concrete tensors and return the result and its profile.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            result (Any): The result of the function invocation
            meta_info (GraphInfo): The memory cost and forward & backward time.
        """
        assert not isinstance(target, str)
        return profile_function(target, self.device)(*args, **kwargs)
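
    # NOTE: every overridden handler in this class returns a `(result, GraphInfo)`
    # pair; `run_node` unpacks it, so the profiling record travels along the
    # normal `Interpreter` control flow.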

    @compatibility(is_backward_compatible=True)
    def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node with concrete tensors and return the result and its profile.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            result (Any): The result of the method invocation
            meta_info (GraphInfo): The memory cost and forward & backward time.
        """
        return profile_method(target, self.device)(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node with concrete tensors and return the result and its profile.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            result (Any): The result of the module invocation
            meta_info (GraphInfo): The memory cost and forward & backward time.
        """
        # Retrieve executed args and kwargs values from the environment
        # Execute the method and return the result
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return profile_module(submod, self.device)(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            result (Any): The argument value that was retrieved
            meta_info (GraphInfo): The memory cost and forward & backward time.
        """
        return args[0], GraphInfo(save_fwd_in=True)

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result,
        recording the shape and type of each node.

        Args:
            *args (Tensor): The sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
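
    # NOTE: `propagate` calls `Interpreter.run` directly, so it skips the device
    # inference performed in `ConcreteInfoProp.run`; when profiling on a concrete
    # device, call `run(*args)` so that `self.device` is populated first.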

    def summary(self, unit: str = 'MB') -> str:
        """
        Summarizes the memory and timing statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.

        Args:
            unit (str): The unit for the memory columns, one of 'KB', 'MB', 'GB' or 'TB'.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
        try:
            from tabulate import tabulate
        except ImportError:
            print("`summary` relies on the library `tabulate`, "
                  "which could not be found on this machine. Run `pip "
                  "install tabulate` to install the library.")
            raise

        assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."

        # Build up a list of summary information for each node
        node_summaries: List[List[Any]] = []

        def mem_repr(mem: int) -> str:
            unit_divisor_map = {
                'kb': 1024,
                'mb': 1024**2,
                'gb': 1024**3,
                'tb': 1024**4,
            }
            return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"

        def time_repr(time: float):
            return f"{time:,} s"

        for node in self.module.graph.nodes:
            node: Node
            node_summaries.append([
                node.op,
                str(node),
                time_repr(node.meta['fwd_time']),
                time_repr(node.meta['bwd_time']),
                node.meta['save_fwd_in'],
                mem_repr(node.meta['fwd_mem_out']),
                mem_repr(node.meta['fwd_mem_tmp']),
                mem_repr(node.meta['bwd_mem_out']),
                mem_repr(node.meta['bwd_mem_tmp']),
            ])

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers: List[str] = [
            'Op type',
            'Op',
            'Forward time',
            'Backward time',
            'SAVE_FWD_IN',
            'FWD_OUT',
            'FWD_TMP',
            'BWD_OUT',
            'BWD_TMP',
        ]

        return tabulate(node_summaries, headers=headers, stralign='right')
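
For reference, a minimal driver for the new pass might look like the following sketch. It assumes a CUDA device, uses `torch.fx.symbolic_trace`, and imports `ConcreteInfoProp` from `colossalai.fx.passes` (the module the first hunk above appears to edit); the per-node keys come from `GraphInfo` via `asdict` in `run_node`.

    import torch
    from torch.fx import symbolic_trace
    from colossalai.fx.passes import ConcreteInfoProp

    model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.Linear(16, 16)).cuda()
    gm = symbolic_trace(model)
    interp = ConcreteInfoProp(gm)
    interp.run(torch.rand(2, 4, device="cuda"))

    # each node now carries the profiled record
    for node in gm.graph.nodes:
        print(node.op, node.name, node.meta['fwd_time'], node.meta['bwd_time'])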

@ -1,5 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import partial
from typing import Dict from typing import Dict
from torch.fx import Graph, Node from torch.fx import Graph, Node
from .memory import activation_size, is_inplace from .memory import activation_size, is_inplace
@ -33,8 +34,10 @@ class GraphInfo:
------------------------------- -------------------------------
============================================================================ ============================================================================
Attributes: Attributes:
fwd_flop (int): The forward FLOPs of a certain node fwd_flop (int): The forward FLOPs of a certain node.
fwd_time (float): The real forward time (s) of a certain node.
bwd_flop (int): The backward FLOPs of a certain node. bwd_flop (int): The backward FLOPs of a certain node.
bwd_time (float): The real backward time (s) of a certain node.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes. save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_mem_tmp (int): See the above illustration. fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration. fwd_mem_out (int): See the above illustration.
@ -42,7 +45,9 @@ class GraphInfo:
bwd_mem_out (int): See the above illustration. bwd_mem_out (int): See the above illustration.
""" """
fwd_flop: int = 0 fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0 bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False save_fwd_in: bool = False
fwd_mem_tmp: int = 0 fwd_mem_tmp: int = 0
fwd_mem_out: int = 0 fwd_mem_out: int = 0
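
Since `GraphInfo` is a dataclass, the record that `run_node` attaches to each node is just its `asdict` expansion; a small sketch of that round trip (field values are made up):

    from dataclasses import asdict
    from colossalai.fx.profiler import GraphInfo

    info = GraphInfo(fwd_time=4.0e-4, bwd_time=7.1e-3, fwd_mem_out=512)
    record = asdict(info)
    assert record['fwd_time'] == 4.0e-4 and record['fwd_mem_out'] == 512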

@@ -5,10 +5,11 @@ from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
from .memory import activation_size, parameter_size
from .constant import ALIAS_ATEN
from .tensor import MetaTensor
from .opcount import flop_mapping
import time

__all__ = ['profile_function', 'profile_module', 'profile_method']

@@ -27,33 +28,112 @@ def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def detach_variables(x):
    if isinstance(x, torch.Tensor):
        requires_grad = x.requires_grad
        x = x.detach()
        x.requires_grad = requires_grad
    return x


def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30

    To profile the actual forward memory, we first run target in the context torch.no_grad() to get
    the fwd_mem_out, then we run target with grad enabled and take the extra memory that stays
    allocated (memory allocated minus the fwd_mem_out) as fwd_mem_tmp.
    To profile the actual backward memory, we first make a dummy gradient for torch.autograd.backward,
    then find the bwd_mem_tmp as the memory peak during the process minus bwd_mem_out (which is actually
    equal to the size of args and kwargs).
    We also add time stamps to profile the real forward and backward time.

    Args:
        target (Callable): A Callable function
        args (Any): Arguments
        kwargs (Any): Arguments

    Returns:
        Tuple[Tuple[Any, ...], GraphInfo]: Output for the next node & the memory cost and real
            forward and backward time.
    """
    graphinfo = GraphInfo()
    # detach the inputs from the outer graph
    args = tree_map(detach_variables, args)
    kwargs = tree_map(detach_variables, kwargs)
    if isinstance(target, str):
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = getattr(self_obj, target)(*args_tail, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: a module's parameters are counted in bwd_mem_out here for the bwd_mem_tmp
        # calculation; `profile_module` subtracts them again afterwards
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        # guard `__self__` so plain functions and string targets do not raise AttributeError
        graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target, "__self__") and hasattr(
            target.__self__, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out
    else:
        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = target(*args, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = target(*args, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target, "__self__") and hasattr(
            target.__self__, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    return tree_map(detach_variables, out), graphinfo


def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
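
The measurement protocol of `_profile_concrete` can be reproduced in isolation; below is a minimal sketch with a made-up workload. One caveat the committed code does not address: CUDA kernels run asynchronously, so `time.time()` deltas may need `torch.cuda.synchronize()` (added here) to be meaningful.

    import time
    import torch

    x = torch.rand(2, 4, device="cuda", requires_grad=True)
    weight = torch.rand(16, 4, device="cuda", requires_grad=True)

    # pass 1: forward under no_grad -> persistent output memory (fwd_mem_out)
    mem0 = torch.cuda.memory_allocated()
    with torch.no_grad():
        out = torch.nn.functional.linear(x, weight)
    fwd_mem_out = torch.cuda.memory_allocated() - mem0
    del out

    # pass 2: forward with grad -> extra saved-for-backward memory (fwd_mem_tmp)
    mem0 = torch.cuda.memory_allocated()
    t0 = time.time()
    out = torch.nn.functional.linear(x, weight)
    torch.cuda.synchronize()
    fwd_time = time.time() - t0
    fwd_mem_tmp = torch.cuda.memory_allocated() - mem0 - fwd_mem_out

    # backward: peak memory minus the gradients that stay allocated (bwd_mem_out)
    torch.cuda.reset_peak_memory_stats()
    mem0 = torch.cuda.memory_allocated()
    t0 = time.time()
    torch.autograd.backward(out, grad_tensors=torch.ones_like(out))
    torch.cuda.synchronize()
    bwd_time = time.time() - t0
    bwd_mem_out = x.numel() * x.element_size() + weight.numel() * weight.element_size()
    bwd_mem_tmp = torch.cuda.max_memory_allocated() - mem0 - bwd_mem_out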

@@ -135,7 +215,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
                                 name=subgraph._graph_namespace.create_name('input', x._tensor))
        x._node.meta['phase'] = Phase.PLACEHOLDER
        x._node.meta['saved_tensor'] = []
        return x

    # Basically, we need to detach the args and kwargs from the outer graph.

@@ -206,12 +285,26 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
            kwargs['inplace'] = False
        if device == 'meta':
            out, meta = _profile_meta(func, *args, **kwargs)
            # currently we set the fwd_mem_tmp of ReLU to zero
            if target in [torch.nn.functional.relu]:
                meta.save_fwd_in = False
                meta.bwd_mem_out = 0
                meta.fwd_mem_tmp = 0
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
            # find the gradients for parameters in args and kwargs
            param_size = 0

            def get_param_size(x):
                # `nonlocal` and `torch.nn.parameter.Parameter` fix the committed
                # version, which referenced the `torch.nn.parameter` module itself
                nonlocal param_size
                if isinstance(x, torch.nn.parameter.Parameter):
                    param_size += activation_size(x)

            tree_map(get_param_size, args)
            tree_map(get_param_size, kwargs)
            meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = target.__name__
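
A hedged usage sketch of the wrapper (the call signature follows the diff; it assumes the `__self__` guard in `_profile_concrete` above, and the `meta` field values depend on the device):

    import torch
    from colossalai.fx.profiler import profile_function

    x = torch.rand(2, 4, device="cuda", requires_grad=True)
    out, meta = profile_function(torch.nn.functional.relu, device="cuda")(x)
    print(meta.fwd_time, meta.bwd_time, meta.fwd_mem_out)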

@@ -257,18 +350,25 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
        # If there is an argument that this `call_module` is inplace, we should
        # still run the profiling but discard some results regarding `module`.
        inplace = getattr(module, 'inplace', False)
        # calculate parameter size
        param_size = parameter_size(module)
        if inplace:
            module.inplace = False
        if device == 'meta':
            out, meta = _profile_meta(func, *args, **kwargs)
            # currently we set the fwd_mem_tmp of ReLU to zero
            if type(module) in [torch.nn.modules.activation.ReLU]:
                meta.save_fwd_in = False
                meta.bwd_mem_out = 0
                meta.fwd_mem_tmp = 0
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
            # the gradients for parameters will not be counted in bwd_mem_out
            meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = module.__class__.__name__
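
Analogous to `profile_function`, a small sketch of profiling a module on a concrete device (illustrative only; it relies on the wrapper internals not shown in this hunk):

    import torch
    from colossalai.fx.profiler import profile_module

    linear = torch.nn.Linear(4, 16).cuda()
    x = torch.rand(2, 4, device="cuda", requires_grad=True)
    out, meta = profile_module(linear, device="cuda")(x)
    # parameter gradients are excluded from bwd_mem_out by the subtraction above
    print(meta.bwd_mem_out, meta.bwd_mem_tmp)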
