[fx] metainfo_trace as an API. (#1873)

* [fx] metainfo_trace as an API.

* [fx] add return.
Super Daniel 2022-11-10 20:58:37 +08:00 committed by GitHub
parent 6d559ea614
commit 448248b27c
3 changed files with 39 additions and 4 deletions

colossalai/fx/__init__.py

@@ -1,4 +1,4 @@
 from ._compatibility import compatibility, is_compatible_with_meta
 from .graph_module import ColoGraphModule
-from .passes import MetaInfoProp
+from .passes import MetaInfoProp, metainfo_trace
 from .tracer import ColoTracer, meta_trace, symbolic_trace

colossalai/fx/passes/__init__.py

@@ -1,4 +1,4 @@
 from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
-from .meta_info_prop import MetaInfoProp
 from .concrete_info_prop import ConcreteInfoProp
+from .meta_info_prop import MetaInfoProp, metainfo_trace
+from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
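
With both re-exports in place, the new pass is importable from either colossalai.fx or colossalai.fx.passes. A minimal, purely illustrative check (not part of the diff):

# Both import paths expose the same function after this commit.
from colossalai.fx import metainfo_trace as trace_from_fx
from colossalai.fx.passes import metainfo_trace as trace_from_passes

assert trace_from_fx is trace_from_passes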

colossalai/fx/passes/meta_info_prop.py

@@ -6,7 +6,7 @@ import torch.fx
 from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_map
-from colossalai.fx._compatibility import compatibility
+from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
 from colossalai.fx.profiler import (
     GraphInfo,
     activation_size,

@@ -315,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter):
         ]
         return tabulate(node_summaries, headers=headers, stralign='right')
+
+
+def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> torch.fx.GraphModule:
+    """
+    MetaInfo tracing API.
+
+    Given a ``GraphModule`` and a sample input, this API traces the MetaInfo of a single training cycle
+    and annotates it on ``gm.graph``.
+
+    Usage:
+        >>> model = ...
+        >>> gm = symbolic_trace(model)
+        >>> args = ...    # sample input to the ``GraphModule``
+        >>> metainfo_trace(gm, *args)
+
+    Args:
+        gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
+        verbose (bool, optional): Whether to print ``MetaInfoProp.summary()``. Defaults to False.
+        unit (str, optional): The unit of memory. Defaults to "MB".
+
+    Returns:
+        torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
+    """
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    interp = MetaInfoProp(gm.to(device))
+    if is_compatible_with_meta():
+        from colossalai.fx.profiler import MetaTensor
+        args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
+        kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
+    interp.propagate(*args, **kwargs)
+    if verbose:
+        print(interp.summary(unit))
+    gm.to('cpu')
+    del interp
+    return gm
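
For context, a hedged end-to-end sketch of calling the new API; the toy model, the input shape, and the inspection of node.meta are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

from colossalai.fx import metainfo_trace, symbolic_trace

# Toy model and sample input, chosen only for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
sample = torch.rand(4, 16)

# Trace the model into a GraphModule, then annotate its graph with MetaInfo.
gm = symbolic_trace(model)
gm = metainfo_trace(gm, sample, verbose=True, unit="MB")

# Every node of gm.graph now carries the profiled metadata in node.meta.
for node in gm.graph.nodes:
    print(node.name, sorted(node.meta.keys()))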