[fx] add a symbolic_trace api. (#1812)

* [fx] add a symbolic_trace api.

* [fx] fix import errors.
pull/1817/head
Super Daniel 2022-11-08 13:59:20 +08:00 committed by GitHub
parent 350ccc0481
commit 441d584e4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 90 additions and 73 deletions

View File

@ -1,4 +1,4 @@
from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
from .tracer import ColoTracer, meta_trace
from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
from .tracer import ColoTracer, meta_trace, symbolic_trace

View File

@ -1,4 +1,5 @@
from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem
from ._meta_trace import meta_trace
from ._symbolic_trace import symbolic_trace
from .tracer import ColoTracer

View File

@ -0,0 +1,58 @@
from typing import Any, Callable, Dict, Optional, Union
import torch
from colossalai.fx import ColoGraphModule
from colossalai.fx._compatibility import compatibility
from .tracer import ColoTracer
@compatibility(is_backward_compatible=True)
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
"""
Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
constructed by recording operations seen while tracing through ``root``.
With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow.
If specified using ``meta_args`` only, the tracing can be done ahead of time.
Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names
and the value of the argument's values.
Uses:
>>> model = ...
# if this works
>>> gm = symbolic_trace(model)
# else try this
>>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
# else try this
>>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
Defaults to None.
Returns:
ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
Warnings:
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
"""
tracer = ColoTracer()
graph = tracer.trace(root, concrete_args, meta_args)
name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
return ColoGraphModule(tracer.root, graph, name)

View File

@ -3,24 +3,19 @@ from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
from colossalai.fx import ColoTracer
from colossalai.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):
# must turn on eval mode to ensure the output is consistent
model.eval()
# make sure that the model is traceable
tracer = ColoTracer()
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = symbolic_trace(model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# run forward
inputs = data_gen()

View File

@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGTH = 16

View File

@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGTH = 16

View File

@ -1,10 +1,9 @@
import pytest
import torch
from torch.fx import GraphModule
from utils import trace_model_and_compare_output
import transformers
from colossalai.fx import ColoTracer
from hf_tracer_utils import trace_model_and_compare_output
from colossalai.fx import symbolic_trace
try:
import diffusers
@ -32,11 +31,7 @@ def test_vae():
model = model_cls()
sample = torch.zeros(LATENTS_SHAPE)
tracer = ColoTracer()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()
@ -98,11 +93,7 @@ def test_unet():
model = model_cls()
sample = torch.zeros(LATENTS_SHAPE)
tracer = ColoTracer()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()

View File

@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGTH = 16

View File

@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGTH = 16

View File

@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGTH = 16

View File

@ -1,12 +1,11 @@
import pytest
import timm.models as tm
import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx import symbolic_trace
def trace_and_compare(model_cls, tracer, data, meta_args=None):
def trace_and_compare(model_cls, data, meta_args=None):
# trace
model = model_cls()
@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model.eval()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model, meta_args=meta_args)
# run forward
with torch.no_grad():
@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
tm.deit_base_distilled_patch16_224,
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
trace_and_compare(model_cls, tracer, data)
trace_and_compare(model_cls, data)
def test_timm_models_with_control_flow():
@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
tm.swin_transformer.swin_base_patch4_window7_224
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
meta_args = {'x': data.to('meta')}
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
trace_and_compare(model_cls, tracer, data, meta_args)
trace_and_compare(model_cls, data, meta_args)
if __name__ == '__main__':

View File

@ -1,20 +1,16 @@
import torch
from torch.fx import GraphModule, Tracer
from colossalai.fx import ColoTracer
from colossalai.fx import symbolic_trace
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
data = data_gen()
concrete_args = data if need_concrete else {}
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
tracer = ColoTracer()
model.eval()
graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)
with torch.no_grad():
non_fx_out = model(**data)

View File

@ -1,9 +1,7 @@
import pytest
import torch
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx import symbolic_trace
try:
from torchrec.models import deepfm
@ -14,8 +12,6 @@ try:
except ImportError:
NOT_TORCHREC = True
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
# Dense Features
features = torch.rand((BATCH, SHAPE))
# Tracer
tracer = ColoTracer()
for model_cls in MODEL_LIST:
# Initializing model
if model_cls == deepfm.DenseArch:
@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
model = model_cls(ebc)
# Setup GraphModule
graph = tracer.trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()

View File

@ -1,6 +1,6 @@
import torch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx import symbolic_trace
try:
from torchrec.models import dlrm
@ -12,7 +12,6 @@ except ImportError:
NOT_TORCHREC = True
import pytest
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
# Sparse Features
sparse_features = torch.rand((BATCH, len(keys), SHAPE))
# Tracer
tracer = ColoTracer()
for model_cls in MODEL_LIST:
# Initializing model
@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
# Setup GraphModule
if model_cls == dlrm.InteractionV2Arch:
concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
graph = tracer.trace(model, concrete_args=concrete_args)
gm = symbolic_trace(model, concrete_args=concrete_args)
else:
graph = tracer.trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()

View File

@ -2,8 +2,8 @@ import torch
import torchvision
import torchvision.models as tm
from packaging import version
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx import symbolic_trace
def test_torchvision_models():
@ -20,7 +20,6 @@ def test_torchvision_models():
torch.backends.cudnn.deterministic = True
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
@ -30,10 +29,7 @@ def test_torchvision_models():
else:
model = model_cls()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()