[fx] add a symbolic_trace api. (#1812)

* [fx] add a symbolic_trace api.

* [fx] fix import errors.
Super Daniel 2022-11-08 13:59:20 +08:00 committed by GitHub
parent 350ccc0481
commit 441d584e4a
15 changed files with 90 additions and 73 deletions


@@ -1,4 +1,4 @@
 from ._compatibility import compatibility, is_compatible_with_meta
 from .graph_module import ColoGraphModule
 from .passes import MetaInfoProp
-from .tracer import ColoTracer, meta_trace
+from .tracer import ColoTracer, meta_trace, symbolic_trace


@@ -1,4 +1,5 @@
 from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem
 from ._meta_trace import meta_trace
+from ._symbolic_trace import symbolic_trace
 from .tracer import ColoTracer


@@ -0,0 +1,58 @@
+from typing import Any, Callable, Dict, Optional, Union
+
+import torch
+
+from colossalai.fx import ColoGraphModule
+from colossalai.fx._compatibility import compatibility
+
+from .tracer import ColoTracer
+
+
+@compatibility(is_backward_compatible=True)
+def symbolic_trace(
+    root: Union[torch.nn.Module, Callable[..., Any]],
+    concrete_args: Optional[Dict[str, Any]] = None,
+    meta_args: Optional[Dict[str, Any]] = None,
+) -> ColoGraphModule:
+    """
+    Symbolic tracing API
+
+    Given an ``nn.Module`` or function instance ``root``, this function returns a ``ColoGraphModule``
+    constructed by recording the operations seen while tracing through ``root``.
+
+    With ``meta_args`` and ``concrete_args``, we can trace models that would otherwise be untraceable
+    due to control flow. If only ``meta_args`` is specified, the tracing can be done ahead of time.
+
+    Note that both ``meta_args`` and ``concrete_args`` are dictionaries that map argument names to
+    argument values.
+
+    Usage:
+        >>> model = ...
+
+        # if this works
+        >>> gm = symbolic_trace(model)
+
+        # else try this
+        >>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
+
+        # else try this
+        >>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
+
+    Args:
+        root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
+            into a Graph representation.
+        concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
+        meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for
+            ``ColoTracer``. Defaults to None.
+
+    Returns:
+        ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
+
+    Warnings:
+        This API is still under development and may contain bugs. Feel free to report any bugs to the
+        Colossal-AI team.
+    """
+    tracer = ColoTracer()
+    graph = tracer.trace(root, concrete_args, meta_args)
+    name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
+    return ColoGraphModule(tracer.root, graph, name)

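For readers skimming the diff, a minimal runnable sketch of the new entry point (the toy ``Net`` module below is hypothetical, not part of this commit):

import torch
from colossalai.fx import symbolic_trace

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x).relu()

# one call now replaces the old tracer.trace + GraphModule + recompile sequence
gm = symbolic_trace(Net())

# keys in meta_args must match the argument names of forward;
# meta tensors let the trace run without allocating real input data
gm = symbolic_trace(Net(), meta_args={'x': torch.rand(2, 8, device='meta')})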

@@ -3,24 +3,19 @@ from numpy import isin
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten
 
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 
 
 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
     model.eval()
 
-    # make sure that the model is traceable
-    tracer = ColoTracer()
-
     try:
         kwargs = data_gen()
         meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-        graph = tracer.trace(root=model, meta_args=meta_args)
+        gm = symbolic_trace(model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
 
     # run forward
     inputs = data_gen()

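As a usage note, the simplified helper is driven from the HF model tests roughly like this; a hedged sketch, where the BertConfig sizes and the data_gen closure are illustrative rather than taken from this commit:

import torch
import transformers
from hf_tracer_utils import trace_model_and_compare_output

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
model = transformers.BertModel(config)

def data_gen():
    # token ids and attention mask shaped (batch, sequence)
    input_ids = torch.zeros((2, 16), dtype=torch.int64)
    attention_mask = torch.ones((2, 16), dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

# traces with meta_args internally, then compares gm output against the eager model
trace_model_and_compare_output(model, data_gen)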

@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
 SEQ_LENGTH = 16


@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
 SEQ_LENGTH = 16


@@ -1,10 +1,9 @@
 import pytest
 import torch
-from torch.fx import GraphModule
-from utils import trace_model_and_compare_output
 import transformers
-from colossalai.fx import ColoTracer
+from hf_tracer_utils import trace_model_and_compare_output
+from colossalai.fx import symbolic_trace
 
 try:
     import diffusers
@@ -32,11 +31,7 @@ def test_vae():
         model = model_cls()
         sample = torch.zeros(LATENTS_SHAPE)
 
-        tracer = ColoTracer()
-        graph = tracer.trace(root=model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()
@@ -98,11 +93,7 @@ def test_unet():
         model = model_cls()
         sample = torch.zeros(LATENTS_SHAPE)
 
-        tracer = ColoTracer()
-        graph = tracer.trace(root=model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()


@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 16


@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 16


@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 16


@@ -1,12 +1,11 @@
 import pytest
 import timm.models as tm
 import torch
-from torch.fx import GraphModule
 
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 
 
-def trace_and_compare(model_cls, tracer, data, meta_args=None):
+def trace_and_compare(model_cls, data, meta_args=None):
     # trace
     model = model_cls()
@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
     # without this statement, the torch.nn.functional.batch_norm will always be in training mode
     model.eval()
 
-    graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model, meta_args=meta_args)
 
     # run forward
     with torch.no_grad():
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
         tm.deit_base_distilled_patch16_224,
     ]
 
-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
     for model_cls in MODEL_LIST:
-        trace_and_compare(model_cls, tracer, data)
+        trace_and_compare(model_cls, data)
 
 
 def test_timm_models_with_control_flow():
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
         tm.swin_transformer.swin_base_patch4_window7_224
     ]
 
-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
     meta_args = {'x': data.to('meta')}
 
     for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
-        trace_and_compare(model_cls, tracer, data, meta_args)
+        trace_and_compare(model_cls, data, meta_args)
 
 
 if __name__ == '__main__':

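One point worth spelling out about the meta_args path exercised above: meta tensors carry shape and dtype but no storage, so models with control flow such as Swin can be traced ahead of time without allocating real inputs. A minimal sketch mirroring the test, assuming a working timm install:

import timm.models as tm
import torch
from colossalai.fx import symbolic_trace

model = tm.swin_transformer.swin_base_patch4_window7_224()
model.eval()

# a meta tensor pins the input shape for tracing but occupies no memory
meta_args = {'x': torch.rand(2, 3, 224, 224, device='meta')}
gm = symbolic_trace(model, meta_args=meta_args)

# the traced module then runs on real data
with torch.no_grad():
    out = gm(torch.rand(2, 3, 224, 224))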

@@ -1,20 +1,16 @@
 import torch
-from torch.fx import GraphModule, Tracer
 
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 
 
 def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
-    tracer = ColoTracer()
 
     model.eval()
-    graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)
 
     with torch.no_grad():
         non_fx_out = model(**data)


@@ -1,9 +1,7 @@
 import pytest
 import torch
-from colossalai.fx.tracer import meta_patch
-from colossalai.fx.tracer.meta_patch.patched_function import python_ops
-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.fx import symbolic_trace
 
 try:
     from torchrec.models import deepfm
@@ -14,8 +12,6 @@ try:
 except ImportError:
     NOT_TORCHREC = True
 
-from torch.fx import GraphModule
-
 BATCH = 2
 SHAPE = 10
@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
     # Dense Features
     features = torch.rand((BATCH, SHAPE))
 
-    # Tracer
-    tracer = ColoTracer()
-
     for model_cls in MODEL_LIST:
         # Initializing model
         if model_cls == deepfm.DenseArch:
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
             model = model_cls(ebc)
 
         # Setup GraphModule
-        graph = tracer.trace(model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()


@@ -1,6 +1,6 @@
 import torch
-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.fx import symbolic_trace
 
 try:
     from torchrec.models import dlrm
@@ -12,7 +12,6 @@ except ImportError:
     NOT_TORCHREC = True
 
 import pytest
-from torch.fx import GraphModule
 
 BATCH = 2
 SHAPE = 10
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
     # Sparse Features
     sparse_features = torch.rand((BATCH, len(keys), SHAPE))
 
-    # Tracer
-    tracer = ColoTracer()
-
     for model_cls in MODEL_LIST:
         # Initializing model
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
         # Setup GraphModule
         if model_cls == dlrm.InteractionV2Arch:
             concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
-            graph = tracer.trace(model, concrete_args=concrete_args)
+            gm = symbolic_trace(model, concrete_args=concrete_args)
         else:
-            graph = tracer.trace(model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+            gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()

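The InteractionV2Arch branch above is the one place this commit passes concrete_args. The distinction, as a hedged sketch (the ``Branchy`` module is hypothetical, and this assumes ColoTracer accepts real tensors in concrete_args as the torchrec test does): concrete_args feeds real values through the trace, resolving data-dependent branches that purely symbolic tracing cannot, while meta_args only fixes shapes.

import torch
from colossalai.fx import symbolic_trace

class Branchy(torch.nn.Module):
    def forward(self, x, flag):
        # data-dependent branch: opaque to purely symbolic tracing
        if flag.sum() > 0:
            return x * 2
        return x

# the real value supplied for `flag` resolves the branch during tracing,
# so the recorded graph is specialized to that path
gm = symbolic_trace(Branchy(), concrete_args={'flag': torch.ones(1)})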

@@ -2,8 +2,8 @@ import torch
 import torchvision
 import torchvision.models as tm
 from packaging import version
-from colossalai.fx import ColoTracer
-from torch.fx import GraphModule
+
+from colossalai.fx import symbolic_trace
 
 
 def test_torchvision_models():
@@ -20,7 +20,6 @@ def test_torchvision_models():
     torch.backends.cudnn.deterministic = True
 
-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
     for model_cls in MODEL_LIST:
@@ -30,10 +29,7 @@ def test_torchvision_models():
         else:
             model = model_cls()
 
-        graph = tracer.trace(root=model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()