[fx/profiler] provide a table of summary. (#1634)

* [fx/profiling] provide summary for MetaInfoProp.

* [fx/profiler] provide a table of summary.

* [fx] optimize table repr.
pull/1643/head
Super Daniel 2022-09-23 18:12:43 +08:00 committed by GitHub
parent 95c35f73bd
commit 04bbabeea8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 14 deletions

View File

@ -1,2 +1,3 @@
from .tracer import ColoTracer, meta_trace
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp

View File

@ -1,2 +1,3 @@
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
from .meta_info_prop import MetaInfoProp

View File

@ -4,7 +4,7 @@ import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Dict
from typing import Any, List, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
@ -48,28 +48,33 @@ class MetaInfoProp(torch.fx.Interpreter):
Usage:
BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Linear(DIM_IN, DIM_OUT)
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
orig_output = model(input_sample)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
for node in gm.graph.nodes:
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)
interp = MetaInfoProp(gm)
interp.run(input_sample)
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB
# output of above code is
# input_1 torch.float32 torch.Size([2, 4]) 8
# weight torch.float32 torch.Size([16, 4]) 64
# bias torch.float32 torch.Size([16]) 16
# linear torch.float32 torch.Size([2, 16]) 32
# output torch.float32 torch.Size([2, 16]) 32
Op type Op Forward FLOPs Backward FLOPs SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- ------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 128 FLOPs 288 FLOPs True 0.12 KB 0.00 KB 0.34 KB 0.00 KB
call_module _1 512 FLOPs 1,056 FLOPs True 0.12 KB 0.00 KB 1.19 KB 0.00 KB
output output 0 FLOPs 0 FLOPs True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed
"""
_is_proped: bool = False
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
@ -84,6 +89,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
Any: The result of executing ``n``
"""
self._is_proped = True
result, meta_info = super().run_node(n)
def extract_tensor_meta(obj):
@ -236,3 +242,64 @@ class MetaInfoProp(torch.fx.Interpreter):
Any: The value returned from executing the Module
"""
return super().run(*args)
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
def flops_repr(flop: int) -> str:
return f"{flop:,} FLOPs"
for node in self.module.graph.nodes:
node: Node
node_summaries.append([
node.op,
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward FLOPs',
'Backward FLOPs',
'SAVE_FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
]
return tabulate(node_summaries, headers=headers, stralign='right')