mirror of https://github.com/hpcaitech/ColossalAI
[fx] add a symbolic_trace api. (#1812)
* [fx] add a symbolic_trace api.
* [fx] fix import errors.

pull/1817/head
parent 350ccc0481
commit 441d584e4a
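The gist of the change: every call site that previously built a ColoTracer, traced a graph, and wrapped it in a GraphModule by hand now makes a single symbolic_trace call. A minimal before/after sketch (the resnet18 model and its input are illustrative, not part of the commit):

    import torch
    import torchvision.models as tm
    from torch.fx import GraphModule
    from colossalai.fx import ColoTracer, symbolic_trace

    model = tm.resnet18()    # any traceable module works; resnet18 is just an example
    model.eval()

    # before this commit: build the tracer and GraphModule by hand
    tracer = ColoTracer()
    graph = tracer.trace(root=model)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # after this commit: one call does all of the above
    gm = symbolic_trace(model)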
@@ -1,4 +1,4 @@
 from ._compatibility import compatibility, is_compatible_with_meta
 from .graph_module import ColoGraphModule
 from .passes import MetaInfoProp
-from .tracer import ColoTracer, meta_trace
+from .tracer import ColoTracer, meta_trace, symbolic_trace
@@ -1,4 +1,5 @@
 from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem

 from ._meta_trace import meta_trace
+from ._symbolic_trace import symbolic_trace
 from .tracer import ColoTracer
@@ -0,0 +1,58 @@
+from typing import Any, Callable, Dict, Optional, Union
+
+import torch
+
+from colossalai.fx import ColoGraphModule
+from colossalai.fx._compatibility import compatibility
+
+from .tracer import ColoTracer
+
+
+@compatibility(is_backward_compatible=True)
+def symbolic_trace(
+    root: Union[torch.nn.Module, Callable[..., Any]],
+    concrete_args: Optional[Dict[str, Any]] = None,
+    meta_args: Optional[Dict[str, Any]] = None,
+) -> ColoGraphModule:
+    """
+    Symbolic tracing API
+
+    Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
+    constructed by recording operations seen while tracing through ``root``.
+
+    With ``meta_args`` and ``concrete_args``, we can trace models that are otherwise untraceable due to control flow.
+    If only ``meta_args`` is specified, the tracing can be done ahead of time.
+
+    Note that both ``meta_args`` and ``concrete_args`` are dicts that map the names of the arguments
+    to their values.
+
+    Usage:
+        >>> model = ...
+
+        # if this works
+        >>> gm = symbolic_trace(model)
+
+        # else try this
+        >>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
+
+        # else try this
+        >>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
+
+    Args:
+        root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
+            into a Graph representation.
+        concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
+        meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
+            Defaults to None.
+
+    Returns:
+        ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
+
+    Warnings:
+        This API is still under development and may contain bugs. Feel free to report any bugs to the Colossal-AI team.
+
+    """
+    tracer = ColoTracer()
+    graph = tracer.trace(root, concrete_args, meta_args)
+    name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
+    return ColoGraphModule(tracer.root, graph, name)
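The returned ``ColoGraphModule`` is called like any ``torch.fx.GraphModule``; the test helpers updated below rely on exactly that. A sketch of the check they perform, assuming a hypothetical MyModel whose forward takes a single tensor ``x``:

    import torch
    from colossalai.fx import symbolic_trace

    model = MyModel()    # hypothetical module, possibly with control flow
    model.eval()

    # meta tensors let ColoTracer specialize the control flow ahead of time,
    # without allocating real input data
    data = torch.rand(2, 3, 224, 224)
    gm = symbolic_trace(model, meta_args={'x': data.to('meta')})

    # the traced module should reproduce the eager output
    with torch.no_grad():
        fx_out = gm(data)
        non_fx_out = model(data)
    assert torch.allclose(fx_out, non_fx_out), 'fx_out does not match non_fx_out'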
@@ -3,24 +3,19 @@ from numpy import isin
-from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten

-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace


 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
     model.eval()

     # make sure that the model is traceable
-    tracer = ColoTracer()
-
     try:
         kwargs = data_gen()
         meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-        graph = tracer.trace(root=model, meta_args=meta_args)
+        gm = symbolic_trace(model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()

     # run forward
     inputs = data_gen()
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output

 BATCH_SIZE = 2
 SEQ_LENGTH = 16
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output

 BATCH_SIZE = 2
 SEQ_LENGTH = 16
@@ -1,10 +1,9 @@
 import pytest
 import torch
-from torch.fx import GraphModule
-from utils import trace_model_and_compare_output
-
 import transformers
-from colossalai.fx import ColoTracer
+from hf_tracer_utils import trace_model_and_compare_output
+
+from colossalai.fx import symbolic_trace

 try:
     import diffusers
@@ -32,11 +31,7 @@ def test_vae():
         model = model_cls()
         sample = torch.zeros(LATENTS_SHAPE)

-        tracer = ColoTracer()
-        graph = tracer.trace(root=model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)

         model.eval()
         gm.eval()
@@ -98,11 +93,7 @@ def test_unet():
         model = model_cls()
         sample = torch.zeros(LATENTS_SHAPE)

-        tracer = ColoTracer()
-        graph = tracer.trace(root=model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)

         model.eval()
         gm.eval()
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output

 BATCH_SIZE = 1
 SEQ_LENGTH = 16
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output

 BATCH_SIZE = 1
 SEQ_LENGTH = 16
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output

 BATCH_SIZE = 1
 SEQ_LENGTH = 16
@@ -1,12 +1,11 @@
 import pytest
 import timm.models as tm
 import torch
-from torch.fx import GraphModule

-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace


-def trace_and_compare(model_cls, tracer, data, meta_args=None):
+def trace_and_compare(model_cls, data, meta_args=None):
     # trace
     model = model_cls()

@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
     # without this statement, the torch.nn.functional.batch_norm will always be in training mode
     model.eval()

-    graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model, meta_args=meta_args)

     # run forward
     with torch.no_grad():
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
         tm.deit_base_distilled_patch16_224,
     ]

-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)

     for model_cls in MODEL_LIST:
-        trace_and_compare(model_cls, tracer, data)
+        trace_and_compare(model_cls, data)


 def test_timm_models_with_control_flow():
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
         tm.swin_transformer.swin_base_patch4_window7_224
     ]

-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)

     meta_args = {'x': data.to('meta')}

     for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
-        trace_and_compare(model_cls, tracer, data, meta_args)
+        trace_and_compare(model_cls, data, meta_args)


 if __name__ == '__main__':
@@ -1,20 +1,16 @@
 import torch
-from torch.fx import GraphModule, Tracer

-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace


 def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
-    tracer = ColoTracer()

     model.eval()

-    graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)

     with torch.no_grad():
         non_fx_out = model(**data)
@@ -1,9 +1,7 @@
 import pytest
 import torch

-from colossalai.fx.tracer import meta_patch
-from colossalai.fx.tracer.meta_patch.patched_function import python_ops
-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.fx import symbolic_trace

 try:
     from torchrec.models import deepfm
@@ -14,8 +12,6 @@ try:
 except ImportError:
     NOT_TORCHREC = True

-from torch.fx import GraphModule
-
 BATCH = 2
 SHAPE = 10

@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
     # Dense Features
     features = torch.rand((BATCH, SHAPE))

-    # Tracer
-    tracer = ColoTracer()
-
     for model_cls in MODEL_LIST:
         # Initializing model
         if model_cls == deepfm.DenseArch:
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
             model = model_cls(ebc)

         # Setup GraphModule
-        graph = tracer.trace(model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)

         model.eval()
         gm.eval()
@@ -1,6 +1,6 @@
 import torch

-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.fx import symbolic_trace

 try:
     from torchrec.models import dlrm
@@ -12,7 +12,6 @@ except ImportError:
     NOT_TORCHREC = True

 import pytest
-from torch.fx import GraphModule

 BATCH = 2
 SHAPE = 10
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():

     # Sparse Features
     sparse_features = torch.rand((BATCH, len(keys), SHAPE))
-    # Tracer
-    tracer = ColoTracer()

     for model_cls in MODEL_LIST:
         # Initializing model
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
         # Setup GraphModule
         if model_cls == dlrm.InteractionV2Arch:
             concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
-            graph = tracer.trace(model, concrete_args=concrete_args)
+            gm = symbolic_trace(model, concrete_args=concrete_args)
         else:
-            graph = tracer.trace(model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+            gm = symbolic_trace(model)

         model.eval()
         gm.eval()
@@ -2,8 +2,8 @@ import torch
 import torchvision
 import torchvision.models as tm
 from packaging import version
-from colossalai.fx import ColoTracer
-from torch.fx import GraphModule
+
+from colossalai.fx import symbolic_trace


 def test_torchvision_models():
@@ -20,7 +20,6 @@ def test_torchvision_models():

     torch.backends.cudnn.deterministic = True

-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)

     for model_cls in MODEL_LIST:
@@ -30,10 +29,7 @@ def test_torchvision_models():
         else:
             model = model_cls()

-        graph = tracer.trace(root=model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)

         model.eval()
         gm.eval()