[fx] fixed torchaudio conformer tracing (#1392)

Frank Lee 2022-08-01 16:08:28 +08:00 committed by GitHub
parent 7d6293927f
commit adf5054ff8
4 changed files with 18 additions and 9 deletions


@@ -31,7 +31,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
     )
 
 
-@pytest.mark.skip
+@pytest.mark.skip("Tracing failed")
 def test_tacotron_model():
     n_mels = 80
     n_batch = 3


@@ -4,7 +4,6 @@ from torchaudio.models import Emformer, Conformer
 import pytest
 
 
-@pytest.mark.skip
 def test_conformer():
     input_dim = 80
     batch_size = 10
@@ -27,10 +26,17 @@ def test_conformer():
         input = torch.rand(batch_size, int(lengths.max()), input_dim)
         return dict(input=input, lengths=lengths)
 
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=True)
+    def kwargs_transform(data):
+        new_data = {}
+        for k, v in data.items():
+            new_data[f'{k}_1'] = v
+        return new_data
+
+    trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)
 
 
-@pytest.mark.skip
+@pytest.mark.skip("Tracing failed")
 def test_emformer():
     input_dim = 128
     batch_size = 10
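
Note on the kwargs_transform hook: tracing with concrete_args makes torch.fx pytree-flatten each concrete argument and name the resulting placeholder '<arg>_<leaf index>', so the recompiled GraphModule expects 'input_1' and 'lengths_1' instead of 'input' and 'lengths'. A minimal sketch of that renaming, using vanilla torch.fx and a hypothetical ToyModel stand-in rather than ColoTracer and Conformer:

import torch
from torch.fx import GraphModule, Tracer


class ToyModel(torch.nn.Module):
    # hypothetical stand-in for the Conformer model

    def forward(self, x):
        return x * 2


model = ToyModel()
data = dict(x=torch.rand(4))

# tracing with concrete_args renames the 'x' placeholder to 'x_1'
tracer = Tracer()
graph = tracer.trace(model, concrete_args=data)
gm = GraphModule(model, graph)
print([n.name for n in gm.graph.nodes if n.op == 'placeholder'])    # ['x_1']

# gm(x=...) would raise a TypeError; the keyword must be renamed to match
out = gm(**{f'{k}_1': v for k, v in data.items()})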


@@ -40,7 +40,7 @@ def _smoke_test(model, device):
     trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
 
 
-@pytest.mark.skip
+@pytest.mark.skip("Tracing failed")
 def test_wav2vec():
     for model_fn in MODEL_LIST:
         _smoke_test(model_fn(), 'cpu')


@@ -3,21 +3,24 @@ import torch
 from torch.fx import GraphModule, Tracer
 
 
-def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False):
+def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
     tracer = ColoTracer()
+    model.eval()
     graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    model.eval()
+    gm.eval()
 
     with torch.no_grad():
         non_fx_out = model(**data)
+        if kwargs_transform:
+            data = kwargs_transform(data)
         fx_out = gm(**data)
     if isinstance(fx_out, tuple):
         for non_fx, fx in zip(non_fx_out, fx_out):
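
Note on the reordered eval() calls: symbolic tracing evaluates Python-level attributes at trace time, so a read of self.training inside forward is recorded in the graph as a constant. A model traced in train mode therefore keeps train-mode behavior no matter what eval() calls follow, which is why model.eval() now runs before tracer.trace(). A minimal sketch of the pitfall, using vanilla torch.fx and a hypothetical DropModel:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace


class DropModel(torch.nn.Module):
    # hypothetical module using the functional dropout API

    def forward(self, x):
        # self.training is read as a plain bool at trace time and
        # baked into the graph as a constant argument
        return F.dropout(x, p=0.5, training=self.training)


model = DropModel()
gm = symbolic_trace(model)    # traced while model.training is True
model.eval()
gm.eval()                     # cannot undo the baked-in training=True

x = torch.ones(4)
print(model(x))    # eval mode: dropout is a no-op, prints all ones
print(gm(x))       # still applies dropout, so the outputs differ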