[fx] fixed torchaudio conformer tracing (#1392)

pull/1395/head
Frank Lee 2022-08-01 16:08:28 +08:00 committed by GitHub
parent 7d6293927f
commit adf5054ff8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 9 deletions

View File

@ -31,7 +31,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
)
@pytest.mark.skip
@pytest.mark.skip("Tracing failed")
def test_tacotron_model():
n_mels = 80
n_batch = 3

View File

@ -4,7 +4,6 @@ from torchaudio.models import Emformer, Conformer
import pytest
@pytest.mark.skip
def test_conformer():
input_dim = 80
batch_size = 10
@ -27,10 +26,17 @@ def test_conformer():
input = torch.rand(batch_size, int(lengths.max()), input_dim)
return dict(input=input, lengths=lengths)
trace_and_compare(model, data_gen, need_meta=False, need_concrete=True)
def kwargs_transform(data):
new_data = {}
for k, v in data.items():
new_data[f'{k}_1'] = v
return new_data
trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)
@pytest.mark.skip
@pytest.mark.skip("Tracing failed")
def test_emformer():
input_dim = 128
batch_size = 10

View File

@ -40,7 +40,7 @@ def _smoke_test(model, device):
trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
@pytest.mark.skip
@pytest.mark.skip("Tracing failed")
def test_wav2vec():
for model_fn in MODEL_LIST:
_smoke_test(model_fn(), 'cpu')

View File

@ -3,21 +3,24 @@ import torch
from torch.fx import GraphModule, Tracer
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False):
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
data = data_gen()
concrete_args = data if need_concrete else {}
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
tracer = ColoTracer()
model.eval()
graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
model.eval()
gm.eval()
with torch.no_grad():
non_fx_out = model(**data)
if kwargs_transform:
data = kwargs_transform(data)
fx_out = gm(**data)
if isinstance(fx_out, tuple):
for non_fx, fx in zip(non_fx_out, fx_out):