From 441d584e4ac92c9b3c03b61bc936f8df3e99e434 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Tue, 8 Nov 2022 13:59:20 +0800 Subject: [PATCH] [fx] add a symbolic_trace api. (#1812) * [fx] add a symbolic_trace api. * [fx] fix import errors. --- colossalai/fx/__init__.py | 8 +-- colossalai/fx/tracer/__init__.py | 1 + colossalai/fx/tracer/_symbolic_trace.py | 58 +++++++++++++++++++ .../{utils.py => hf_tracer_utils.py} | 9 +-- .../test_hf_model/test_hf_albert.py | 2 +- .../test_tracer/test_hf_model/test_hf_bert.py | 2 +- .../test_hf_model/test_hf_diffuser.py | 19 ++---- .../test_tracer/test_hf_model/test_hf_gpt.py | 2 +- .../test_tracer/test_hf_model/test_hf_opt.py | 2 +- .../test_tracer/test_hf_model/test_hf_t5.py | 2 +- .../test_timm_model/test_timm_model.py | 15 ++--- .../test_torchaudio_model/torchaudio_utils.py | 8 +-- .../test_torchrec_model/test_deepfm_model.py | 13 +---- .../test_torchrec_model/test_dlrm_model.py | 12 +--- .../test_torchvision_model.py | 10 +--- 15 files changed, 90 insertions(+), 73 deletions(-) create mode 100644 colossalai/fx/tracer/_symbolic_trace.py rename tests/test_fx/test_tracer/test_hf_model/{utils.py => hf_tracer_utils.py} (77%) diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index 5693f3eac..6bbbf0ebf 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1,4 +1,4 @@ -from ._compatibility import compatibility, is_compatible_with_meta -from .graph_module import ColoGraphModule -from .passes import MetaInfoProp -from .tracer import ColoTracer, meta_trace +from ._compatibility import compatibility, is_compatible_with_meta +from .graph_module import ColoGraphModule +from .passes import MetaInfoProp +from .tracer import ColoTracer, meta_trace, symbolic_trace diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py index bf88cc1c1..590555ce3 100644 --- a/colossalai/fx/tracer/__init__.py +++ b/colossalai/fx/tracer/__init__.py @@ -1,4 +1,5 @@ from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem from ._meta_trace import meta_trace +from ._symbolic_trace import symbolic_trace from .tracer import ColoTracer diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py new file mode 100644 index 000000000..39da62473 --- /dev/null +++ b/colossalai/fx/tracer/_symbolic_trace.py @@ -0,0 +1,58 @@ +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from colossalai.fx import ColoGraphModule +from colossalai.fx._compatibility import compatibility + +from .tracer import ColoTracer + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, +) -> ColoGraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule`` + constructed by recording operations seen while tracing through ``root``. + + With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow. + If specified using ``meta_args`` only, the tracing can be done ahead of time. + + Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names + and the value of the argument's values. + + Uses: + >>> model = ... + + # if this works + >>> gm = symbolic_trace(model) + + # else try this + >>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')}) + + # else try this + >>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)}) + + Args: + root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None. + meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``. + Defaults to None. + + Returns: + ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``. + + Warnings: + This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team. + + """ + tracer = ColoTracer() + graph = tracer.trace(root, concrete_args, meta_args) + name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__) + return ColoGraphModule(tracer.root, graph, name) diff --git a/tests/test_fx/test_tracer/test_hf_model/utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py similarity index 77% rename from tests/test_fx/test_tracer/test_hf_model/utils.py rename to tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index fb0702455..6d93fe040 100644 --- a/tests/test_fx/test_tracer/test_hf_model/utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -3,24 +3,19 @@ from numpy import isin from torch.fx import GraphModule from torch.utils._pytree import tree_flatten -from colossalai.fx import ColoTracer +from colossalai.fx import symbolic_trace def trace_model_and_compare_output(model, data_gen): # must turn on eval mode to ensure the output is consistent model.eval() - # make sure that the model is traceable - tracer = ColoTracer() - try: kwargs = data_gen() meta_args = {k: v.to('meta') for k, v in kwargs.items()} - graph = tracer.trace(root=model, meta_args=meta_args) + gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() # run forward inputs = data_gen() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index 5837340fa..9c36b0c9c 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 2 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 1a66b1151..62273e2d5 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 2 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index e02885e38..04e874bec 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -1,10 +1,9 @@ import pytest import torch -from torch.fx import GraphModule -from utils import trace_model_and_compare_output - import transformers -from colossalai.fx import ColoTracer +from hf_tracer_utils import trace_model_and_compare_output + +from colossalai.fx import symbolic_trace try: import diffusers @@ -32,11 +31,7 @@ def test_vae(): model = model_cls() sample = torch.zeros(LATENTS_SHAPE) - tracer = ColoTracer() - graph = tracer.trace(root=model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() @@ -98,11 +93,7 @@ def test_unet(): model = model_cls() sample = torch.zeros(LATENTS_SHAPE) - tracer = ColoTracer() - graph = tracer.trace(root=model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index ae2e752f9..269bc26f3 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index c39e97a16..06260176e 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index b6749c828..71e782fdd 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 44b605a4e..28ec3d825 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -1,12 +1,11 @@ import pytest import timm.models as tm import torch -from torch.fx import GraphModule -from colossalai.fx import ColoTracer +from colossalai.fx import symbolic_trace -def trace_and_compare(model_cls, tracer, data, meta_args=None): +def trace_and_compare(model_cls, data, meta_args=None): # trace model = model_cls() @@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None): # without this statement, the torch.nn.functional.batch_norm will always be in training mode model.eval() - graph = tracer.trace(root=model, meta_args=meta_args) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model, meta_args=meta_args) # run forward with torch.no_grad(): @@ -49,11 +46,10 @@ def test_timm_models_without_control_flow(): tm.deit_base_distilled_patch16_224, ] - tracer = ColoTracer() data = torch.rand(2, 3, 224, 224) for model_cls in MODEL_LIST: - trace_and_compare(model_cls, tracer, data) + trace_and_compare(model_cls, data) def test_timm_models_with_control_flow(): @@ -64,13 +60,12 @@ def test_timm_models_with_control_flow(): tm.swin_transformer.swin_base_patch4_window7_224 ] - tracer = ColoTracer() data = torch.rand(2, 3, 224, 224) meta_args = {'x': data.to('meta')} for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: - trace_and_compare(model_cls, tracer, data, meta_args) + trace_and_compare(model_cls, data, meta_args) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py index f40cad04d..702c5f8f6 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -1,20 +1,16 @@ import torch -from torch.fx import GraphModule, Tracer -from colossalai.fx import ColoTracer +from colossalai.fx import symbolic_trace def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False): data = data_gen() concrete_args = data if need_concrete else {} meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} - tracer = ColoTracer() model.eval() - graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) with torch.no_grad(): non_fx_out = model(**data) diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index d2efc3c45..dbe8a62e7 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,9 +1,7 @@ import pytest import torch -from colossalai.fx.tracer import meta_patch -from colossalai.fx.tracer.meta_patch.patched_function import python_ops -from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.fx import symbolic_trace try: from torchrec.models import deepfm @@ -14,8 +12,6 @@ try: except ImportError: NOT_TORCHREC = True -from torch.fx import GraphModule - BATCH = 2 SHAPE = 10 @@ -43,9 +39,6 @@ def test_torchrec_deepfm_models(): # Dense Features features = torch.rand((BATCH, SHAPE)) - # Tracer - tracer = ColoTracer() - for model_cls in MODEL_LIST: # Initializing model if model_cls == deepfm.DenseArch: @@ -60,9 +53,7 @@ def test_torchrec_deepfm_models(): model = model_cls(ebc) # Setup GraphModule - graph = tracer.trace(model) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 4050c7d3c..2f9fd8fe5 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,6 +1,6 @@ import torch -from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.fx import symbolic_trace try: from torchrec.models import dlrm @@ -12,7 +12,6 @@ except ImportError: NOT_TORCHREC = True import pytest -from torch.fx import GraphModule BATCH = 2 SHAPE = 10 @@ -51,8 +50,6 @@ def test_torchrec_dlrm_models(): # Sparse Features sparse_features = torch.rand((BATCH, len(keys), SHAPE)) - # Tracer - tracer = ColoTracer() for model_cls in MODEL_LIST: # Initializing model @@ -77,12 +74,9 @@ def test_torchrec_dlrm_models(): # Setup GraphModule if model_cls == dlrm.InteractionV2Arch: concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features} - graph = tracer.trace(model, concrete_args=concrete_args) + gm = symbolic_trace(model, concrete_args=concrete_args) else: - graph = tracer.trace(model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 046a0dabe..2a6c6ae16 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -2,8 +2,8 @@ import torch import torchvision import torchvision.models as tm from packaging import version -from colossalai.fx import ColoTracer -from torch.fx import GraphModule + +from colossalai.fx import symbolic_trace def test_torchvision_models(): @@ -20,7 +20,6 @@ def test_torchvision_models(): torch.backends.cudnn.deterministic = True - tracer = ColoTracer() data = torch.rand(2, 3, 224, 224) for model_cls in MODEL_LIST: @@ -30,10 +29,7 @@ def test_torchvision_models(): else: model = model_cls() - graph = tracer.trace(root=model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval()