mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed torchaudio conformer tracing (#1392)
parent
7d6293927f
commit
adf5054ff8
|
@ -31,7 +31,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip("Tracing failed")
|
||||
def test_tacotron_model():
|
||||
n_mels = 80
|
||||
n_batch = 3
|
||||
|
|
|
@ -4,7 +4,6 @@ from torchaudio.models import Emformer, Conformer
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_conformer():
|
||||
input_dim = 80
|
||||
batch_size = 10
|
||||
|
@ -27,10 +26,17 @@ def test_conformer():
|
|||
input = torch.rand(batch_size, int(lengths.max()), input_dim)
|
||||
return dict(input=input, lengths=lengths)
|
||||
|
||||
trace_and_compare(model, data_gen, need_meta=False, need_concrete=True)
|
||||
def kwargs_transform(data):
|
||||
new_data = {}
|
||||
|
||||
for k, v in data.items():
|
||||
new_data[f'{k}_1'] = v
|
||||
return new_data
|
||||
|
||||
trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip("Tracing failed")
|
||||
def test_emformer():
|
||||
input_dim = 128
|
||||
batch_size = 10
|
||||
|
|
|
@ -40,7 +40,7 @@ def _smoke_test(model, device):
|
|||
trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip("Tracing failed")
|
||||
def test_wav2vec():
|
||||
for model_fn in MODEL_LIST:
|
||||
_smoke_test(model_fn(), 'cpu')
|
||||
|
|
|
@ -3,21 +3,24 @@ import torch
|
|||
from torch.fx import GraphModule, Tracer
|
||||
|
||||
|
||||
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False):
|
||||
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
|
||||
data = data_gen()
|
||||
concrete_args = data if need_concrete else {}
|
||||
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
|
||||
tracer = ColoTracer()
|
||||
|
||||
model.eval()
|
||||
|
||||
graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
model.eval()
|
||||
gm.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(**data)
|
||||
|
||||
if kwargs_transform:
|
||||
data = kwargs_transform(data)
|
||||
|
||||
fx_out = gm(**data)
|
||||
if isinstance(fx_out, tuple):
|
||||
for non_fx, fx in zip(non_fx_out, fx_out):
|
||||
|
|
Loading…
Reference in New Issue