mirror of https://github.com/hpcaitech/ColossalAI
[fx/profiler] provide a table of summary. (#1634)
* [fx/profiling] provide summary for MetaInfoProp. * [fx/profiler] provide a table of summary. * [fx] optimize table repr.pull/1643/head
parent
95c35f73bd
commit
04bbabeea8
|
@ -1,2 +1,3 @@
|
|||
from .tracer import ColoTracer, meta_trace
|
||||
from .graph_module import ColoGraphModule
|
||||
from .passes import MetaInfoProp
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
||||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
||||
from .meta_info_prop import MetaInfoProp
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.fx
|
||||
from torch.fx.node import Node, Argument, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from typing import Any, Tuple, NamedTuple, Dict
|
||||
from typing import Any, List, Tuple, NamedTuple, Dict
|
||||
from torch.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
|
||||
|
||||
|
@ -48,28 +48,33 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
Usage:
|
||||
BATCH_SIZE = 2
|
||||
DIM_IN = 4
|
||||
DIM_HIDDEN = 16
|
||||
DIM_OUT = 16
|
||||
model = torch.nn.Linear(DIM_IN, DIM_OUT)
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
|
||||
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
|
||||
)
|
||||
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
|
||||
orig_output = model(input_sample)
|
||||
gm = symbolic_trace(model)
|
||||
MetaInfoProp(gm).run(input_sample)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
print(node.name, node.meta['tensor_meta'].dtype,
|
||||
node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.run(input_sample)
|
||||
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB
|
||||
|
||||
|
||||
# output of above code is
|
||||
# input_1 torch.float32 torch.Size([2, 4]) 8
|
||||
# weight torch.float32 torch.Size([16, 4]) 64
|
||||
# bias torch.float32 torch.Size([16]) 16
|
||||
# linear torch.float32 torch.Size([2, 16]) 32
|
||||
# output torch.float32 torch.Size([2, 16]) 32
|
||||
Op type Op Forward FLOPs Backward FLOPs SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
|
||||
----------- ------- --------------- ---------------- ------------- --------- --------- --------- ---------
|
||||
placeholder input_1 0 FLOPs 0 FLOPs False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
|
||||
call_module _0 128 FLOPs 288 FLOPs True 0.12 KB 0.00 KB 0.34 KB 0.00 KB
|
||||
call_module _1 512 FLOPs 1,056 FLOPs True 0.12 KB 0.00 KB 1.19 KB 0.00 KB
|
||||
output output 0 FLOPs 0 FLOPs True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
|
||||
Args:
|
||||
module (GraphModule): The module to be executed
|
||||
|
||||
"""
|
||||
|
||||
_is_proped: bool = False
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def run_node(self, n: Node) -> Any:
|
||||
"""
|
||||
|
@ -84,6 +89,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
Returns:
|
||||
Any: The result of executing ``n``
|
||||
"""
|
||||
self._is_proped = True
|
||||
result, meta_info = super().run_node(n)
|
||||
|
||||
def extract_tensor_meta(obj):
|
||||
|
@ -236,3 +242,64 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
Any: The value returned from executing the Module
|
||||
"""
|
||||
return super().run(*args)
|
||||
|
||||
def summary(self, unit: str = 'MB') -> str:
|
||||
"""
|
||||
Summarizes the memory and FLOPs statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
to be installed.
|
||||
"""
|
||||
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
|
||||
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
|
||||
|
||||
# Build up a list of summary information for each node
|
||||
node_summaries: List[List[Any]] = []
|
||||
|
||||
def mem_repr(mem: int) -> str:
|
||||
unit_divisor_map = {
|
||||
'kb': 1024,
|
||||
'mb': 1024**2,
|
||||
'gb': 1024**3,
|
||||
'tb': 1024**4,
|
||||
}
|
||||
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
|
||||
|
||||
def flops_repr(flop: int) -> str:
|
||||
return f"{flop:,} FLOPs"
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
flops_repr(node.meta['fwd_flop']),
|
||||
flops_repr(node.meta['bwd_flop']),
|
||||
node.meta['save_fwd_in'],
|
||||
mem_repr(node.meta['fwd_mem_out']),
|
||||
mem_repr(node.meta['fwd_mem_tmp']),
|
||||
mem_repr(node.meta['bwd_mem_out']),
|
||||
mem_repr(node.meta['bwd_mem_tmp']),
|
||||
])
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Forward FLOPs',
|
||||
'Backward FLOPs',
|
||||
'SAVE_FWD_IN',
|
||||
'FWD_OUT',
|
||||
'FWD_TMP',
|
||||
'BWD_OUT',
|
||||
'BWD_TMP',
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
|
|
Loading…
Reference in New Issue