From adf5054ff8d1884053b5631585b479ffcd2279f4 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 1 Aug 2022 16:08:28 +0800 Subject: [PATCH] [fx] fixed torchaudio conformer tracing (#1392) --- .../test_torchaudio_tacotron.py | 2 +- .../test_torchaudio_transformer.py | 12 +++++++++--- .../test_torchaudio_wave2vec.py | 2 +- .../test_torchaudio_model/torchaudio_utils.py | 11 +++++++---- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py index 165ac6bb0..2073c4689 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py @@ -31,7 +31,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5): ) -@pytest.mark.skip +@pytest.mark.skip("Tracing failed") def test_tacotron_model(): n_mels = 80 n_batch = 3 diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py index fb473039b..fbe24a8cd 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py @@ -4,7 +4,6 @@ from torchaudio.models import Emformer, Conformer import pytest -@pytest.mark.skip def test_conformer(): input_dim = 80 batch_size = 10 @@ -27,10 +26,17 @@ def test_conformer(): input = torch.rand(batch_size, int(lengths.max()), input_dim) return dict(input=input, lengths=lengths) - trace_and_compare(model, data_gen, need_meta=False, need_concrete=True) + def kwargs_transform(data): + new_data = {} + + for k, v in data.items(): + new_data[f'{k}_1'] = v + return new_data + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform) -@pytest.mark.skip +@pytest.mark.skip("Tracing failed") def test_emformer(): input_dim = 128 batch_size = 10 diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py index fe25ab97f..e8729b83f 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py @@ -40,7 +40,7 @@ def _smoke_test(model, device): trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) -@pytest.mark.skip +@pytest.mark.skip("Tracing failed") def test_wav2vec(): for model_fn in MODEL_LIST: _smoke_test(model_fn(), 'cpu') diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py index cee555df3..894810fe6 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -3,21 +3,24 @@ import torch from torch.fx import GraphModule, Tracer -def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False): +def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False): data = data_gen() concrete_args = data if need_concrete else {} meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} tracer = ColoTracer() + model.eval() + graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - model.eval() - gm.eval() - with torch.no_grad(): non_fx_out = model(**data) + + if kwargs_transform: + data = kwargs_transform(data) + fx_out = gm(**data) if isinstance(fx_out, tuple): for non_fx, fx in zip(non_fx_out, fx_out):