mirror of https://github.com/hpcaitech/ColossalAI
[fx/profiler] provide a table of summary. (#1634)
* [fx/profiling] provide summary for MetaInfoProp. * [fx/profiler] provide a table of summary. * [fx] optimize table repr.pull/1643/head
parent
95c35f73bd
commit
04bbabeea8
|
@ -1,2 +1,3 @@
|
||||||
from .tracer import ColoTracer, meta_trace
|
from .tracer import ColoTracer, meta_trace
|
||||||
from .graph_module import ColoGraphModule
|
from .graph_module import ColoGraphModule
|
||||||
|
from .passes import MetaInfoProp
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
||||||
|
from .meta_info_prop import MetaInfoProp
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.fx.node import Node, Argument, Target
|
from torch.fx.node import Node, Argument, Target
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
from typing import Any, Tuple, NamedTuple, Dict
|
from typing import Any, List, Tuple, NamedTuple, Dict
|
||||||
from torch.fx._compatibility import compatibility
|
from torch.fx._compatibility import compatibility
|
||||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
|
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
|
||||||
|
|
||||||
|
@ -48,28 +48,33 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
Usage:
|
Usage:
|
||||||
BATCH_SIZE = 2
|
BATCH_SIZE = 2
|
||||||
DIM_IN = 4
|
DIM_IN = 4
|
||||||
|
DIM_HIDDEN = 16
|
||||||
DIM_OUT = 16
|
DIM_OUT = 16
|
||||||
model = torch.nn.Linear(DIM_IN, DIM_OUT)
|
model = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
|
||||||
|
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
|
||||||
|
)
|
||||||
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
|
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
|
||||||
orig_output = model(input_sample)
|
|
||||||
gm = symbolic_trace(model)
|
gm = symbolic_trace(model)
|
||||||
MetaInfoProp(gm).run(input_sample)
|
interp = MetaInfoProp(gm)
|
||||||
|
interp.run(input_sample)
|
||||||
for node in gm.graph.nodes:
|
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB
|
||||||
print(node.name, node.meta['tensor_meta'].dtype,
|
|
||||||
node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)
|
|
||||||
|
|
||||||
# output of above code is
|
# output of above code is
|
||||||
# input_1 torch.float32 torch.Size([2, 4]) 8
|
Op type Op Forward FLOPs Backward FLOPs SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
|
||||||
# weight torch.float32 torch.Size([16, 4]) 64
|
----------- ------- --------------- ---------------- ------------- --------- --------- --------- ---------
|
||||||
# bias torch.float32 torch.Size([16]) 16
|
placeholder input_1 0 FLOPs 0 FLOPs False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
|
||||||
# linear torch.float32 torch.Size([2, 16]) 32
|
call_module _0 128 FLOPs 288 FLOPs True 0.12 KB 0.00 KB 0.34 KB 0.00 KB
|
||||||
# output torch.float32 torch.Size([2, 16]) 32
|
call_module _1 512 FLOPs 1,056 FLOPs True 0.12 KB 0.00 KB 1.19 KB 0.00 KB
|
||||||
|
output output 0 FLOPs 0 FLOPs True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
|
||||||
Args:
|
Args:
|
||||||
module (GraphModule): The module to be executed
|
module (GraphModule): The module to be executed
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_is_proped: bool = False
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
def run_node(self, n: Node) -> Any:
|
def run_node(self, n: Node) -> Any:
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +89,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
Returns:
|
Returns:
|
||||||
Any: The result of executing ``n``
|
Any: The result of executing ``n``
|
||||||
"""
|
"""
|
||||||
|
self._is_proped = True
|
||||||
result, meta_info = super().run_node(n)
|
result, meta_info = super().run_node(n)
|
||||||
|
|
||||||
def extract_tensor_meta(obj):
|
def extract_tensor_meta(obj):
|
||||||
|
@ -236,3 +242,64 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
Any: The value returned from executing the Module
|
Any: The value returned from executing the Module
|
||||||
"""
|
"""
|
||||||
return super().run(*args)
|
return super().run(*args)
|
||||||
|
|
||||||
|
def summary(self, unit: str = 'MB') -> str:
|
||||||
|
"""
|
||||||
|
Summarizes the memory and FLOPs statistics of the `GraphModule` in
|
||||||
|
tabular format. Note that this API requires the ``tabulate`` module
|
||||||
|
to be installed.
|
||||||
|
"""
|
||||||
|
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
except ImportError:
|
||||||
|
print("`summary` relies on the library `tabulate`, "
|
||||||
|
"which could not be found on this machine. Run `pip "
|
||||||
|
"install tabulate` to install the library.")
|
||||||
|
|
||||||
|
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
|
||||||
|
|
||||||
|
# Build up a list of summary information for each node
|
||||||
|
node_summaries: List[List[Any]] = []
|
||||||
|
|
||||||
|
def mem_repr(mem: int) -> str:
|
||||||
|
unit_divisor_map = {
|
||||||
|
'kb': 1024,
|
||||||
|
'mb': 1024**2,
|
||||||
|
'gb': 1024**3,
|
||||||
|
'tb': 1024**4,
|
||||||
|
}
|
||||||
|
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
|
||||||
|
|
||||||
|
def flops_repr(flop: int) -> str:
|
||||||
|
return f"{flop:,} FLOPs"
|
||||||
|
|
||||||
|
for node in self.module.graph.nodes:
|
||||||
|
node: Node
|
||||||
|
node_summaries.append([
|
||||||
|
node.op,
|
||||||
|
str(node),
|
||||||
|
flops_repr(node.meta['fwd_flop']),
|
||||||
|
flops_repr(node.meta['bwd_flop']),
|
||||||
|
node.meta['save_fwd_in'],
|
||||||
|
mem_repr(node.meta['fwd_mem_out']),
|
||||||
|
mem_repr(node.meta['fwd_mem_tmp']),
|
||||||
|
mem_repr(node.meta['bwd_mem_out']),
|
||||||
|
mem_repr(node.meta['bwd_mem_tmp']),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Use the ``tabulate`` library to create a well-formatted table
|
||||||
|
# presenting our summary information
|
||||||
|
headers: List[str] = [
|
||||||
|
'Op type',
|
||||||
|
'Op',
|
||||||
|
'Forward FLOPs',
|
||||||
|
'Backward FLOPs',
|
||||||
|
'SAVE_FWD_IN',
|
||||||
|
'FWD_OUT',
|
||||||
|
'FWD_TMP',
|
||||||
|
'BWD_OUT',
|
||||||
|
'BWD_TMP',
|
||||||
|
]
|
||||||
|
|
||||||
|
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||||
|
|
Loading…
Reference in New Issue