From 448248b27cadc082bc58da564303b455000e3374 Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Thu, 10 Nov 2022 20:58:37 +0800
Subject: [PATCH] [fx] metainfo_trace as an API. (#1873)

* [fx] metainfo_trace as an API.

* [fx] add return.
---
 colossalai/fx/__init__.py              |  2 +-
 colossalai/fx/passes/__init__.py       |  4 +--
 colossalai/fx/passes/meta_info_prop.py | 37 +++++++++++++++++++++++++-
 3 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py
index 6bbbf0ebf..d39fa5799 100644
--- a/colossalai/fx/__init__.py
+++ b/colossalai/fx/__init__.py
@@ -1,4 +1,4 @@
 from ._compatibility import compatibility, is_compatible_with_meta
 from .graph_module import ColoGraphModule
-from .passes import MetaInfoProp
+from .passes import MetaInfoProp, metainfo_trace
 from .tracer import ColoTracer, meta_trace, symbolic_trace
diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py
index 43ac14ec4..6f948cb2d 100644
--- a/colossalai/fx/passes/__init__.py
+++ b/colossalai/fx/passes/__init__.py
@@ -1,4 +1,4 @@
 from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
-from .meta_info_prop import MetaInfoProp
 from .concrete_info_prop import ConcreteInfoProp
+from .meta_info_prop import MetaInfoProp, metainfo_trace
+from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 90009b22b..711439955 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -6,7 +6,7 @@ import torch.fx
 from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_map
 
-from colossalai.fx._compatibility import compatibility
+from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
 from colossalai.fx.profiler import (
     GraphInfo,
     activation_size,
@@ -315,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter):
         ]
 
         return tabulate(node_summaries, headers=headers, stralign='right')
+
+
+def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> torch.fx.GraphModule:
+    """
+    MetaInfo tracing API.
+
+    Given a ``GraphModule`` and a sample input, this API traces the MetaInfo of a single training cycle
+    and annotates it on ``gm.graph``.
+
+    Usage:
+        >>> model = ...
+        >>> gm = symbolic_trace(model)
+        >>> args = ...    # sample input to the ``GraphModule``
+        >>> gm = metainfo_trace(gm, *args)
+
+    Args:
+        gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
+        verbose (bool, optional): Whether to show ``MetaInfoProp.summary()``. Defaults to False.
+        unit (str, optional): The unit of memory. Defaults to "MB".
+
+    Returns:
+        torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
+    """
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    interp = MetaInfoProp(gm.to(device))
+    if is_compatible_with_meta():
+        from colossalai.fx.profiler import MetaTensor
+        args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
+        kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
+    interp.propagate(*args, **kwargs)
+    if verbose:
+        interp.summary(unit)
+    gm.to('cpu')
+    del interp
+    return gm
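
For reference, a minimal usage sketch of the new metainfo_trace API as exported by this patch from colossalai.fx. The torchvision ResNet-18, the input shape, and the availability of CUDA are illustrative assumptions, not part of the patch:

    # Usage sketch (assumes torchvision is installed and the model is traceable by ColoTracer).
    import torch
    import torchvision.models as tm

    from colossalai.fx import metainfo_trace, symbolic_trace

    model = tm.resnet18()                                   # illustrative model choice
    gm = symbolic_trace(model)                              # trace the model into a ColoGraphModule
    data = torch.rand(4, 3, 224, 224)                       # sample input matching the model's forward()
    gm = metainfo_trace(gm, data, verbose=True, unit="MB")  # annotate gm.graph with MetaInfo and print a summary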