mirror of https://github.com/hpcaitech/ColossalAI
[fx] metainfo_trace as an API. (#1873)
* [fx] metainfo_trace as an API. * [fx] add return.pull/1797/head
parent
6d559ea614
commit
448248b27c
|
@ -1,4 +1,4 @@
|
|||
from ._compatibility import compatibility, is_compatible_with_meta
|
||||
from .graph_module import ColoGraphModule
|
||||
from .passes import MetaInfoProp
|
||||
from .passes import MetaInfoProp, metainfo_trace
|
||||
from .tracer import ColoTracer, meta_trace, symbolic_trace
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
||||
from .meta_info_prop import MetaInfoProp
|
||||
from .concrete_info_prop import ConcreteInfoProp
|
||||
from .meta_info_prop import MetaInfoProp, metainfo_trace
|
||||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.fx
|
|||
from torch.fx.node import Argument, Node, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
|
||||
from colossalai.fx.profiler import (
|
||||
GraphInfo,
|
||||
activation_size,
|
||||
|
@ -315,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
|
||||
|
||||
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
|
||||
"""
|
||||
MetaInfo tracing API
|
||||
|
||||
Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle,
|
||||
and annotate them on ``gm.graph``.
|
||||
|
||||
Uses:
|
||||
>>> model = ...
|
||||
>>> gm = symbolic_trace(model)
|
||||
>>> args = ... # sample input to the ``GraphModule``
|
||||
>>> metainfo_trace(gm, *args)
|
||||
|
||||
Args:
|
||||
gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
|
||||
verbose (bool, optional): Whether to show ``MetaInfoProp.summary()`. Defaults to False.
|
||||
unit (str, optional): The unit of memory. Defaults to "MB".
|
||||
|
||||
Returns:
|
||||
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
|
||||
"""
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
interp = MetaInfoProp(gm.to(device))
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
|
||||
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
|
||||
interp.propagate(*args, **kwargs)
|
||||
if verbose:
|
||||
interp.summary(unit)
|
||||
gm.to('cpu')
|
||||
del interp
|
||||
return gm
|
||||
|
|
Loading…
Reference in New Issue