mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed torchaudio conformer tracing (#1392)
parent
7d6293927f
commit
adf5054ff8
|
@ -31,7 +31,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip("Tracing failed")
|
||||||
def test_tacotron_model():
|
def test_tacotron_model():
|
||||||
n_mels = 80
|
n_mels = 80
|
||||||
n_batch = 3
|
n_batch = 3
|
||||||
|
|
|
@ -4,7 +4,6 @@ from torchaudio.models import Emformer, Conformer
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
def test_conformer():
|
def test_conformer():
|
||||||
input_dim = 80
|
input_dim = 80
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
|
@ -27,10 +26,17 @@ def test_conformer():
|
||||||
input = torch.rand(batch_size, int(lengths.max()), input_dim)
|
input = torch.rand(batch_size, int(lengths.max()), input_dim)
|
||||||
return dict(input=input, lengths=lengths)
|
return dict(input=input, lengths=lengths)
|
||||||
|
|
||||||
trace_and_compare(model, data_gen, need_meta=False, need_concrete=True)
|
def kwargs_transform(data):
|
||||||
|
new_data = {}
|
||||||
|
|
||||||
|
for k, v in data.items():
|
||||||
|
new_data[f'{k}_1'] = v
|
||||||
|
return new_data
|
||||||
|
|
||||||
|
trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip("Tracing failed")
|
||||||
def test_emformer():
|
def test_emformer():
|
||||||
input_dim = 128
|
input_dim = 128
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
|
|
|
@ -40,7 +40,7 @@ def _smoke_test(model, device):
|
||||||
trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip("Tracing failed")
|
||||||
def test_wav2vec():
|
def test_wav2vec():
|
||||||
for model_fn in MODEL_LIST:
|
for model_fn in MODEL_LIST:
|
||||||
_smoke_test(model_fn(), 'cpu')
|
_smoke_test(model_fn(), 'cpu')
|
||||||
|
|
|
@ -3,21 +3,24 @@ import torch
|
||||||
from torch.fx import GraphModule, Tracer
|
from torch.fx import GraphModule, Tracer
|
||||||
|
|
||||||
|
|
||||||
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False):
|
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
|
||||||
data = data_gen()
|
data = data_gen()
|
||||||
concrete_args = data if need_concrete else {}
|
concrete_args = data if need_concrete else {}
|
||||||
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
|
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
|
graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
model.eval()
|
|
||||||
gm.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
non_fx_out = model(**data)
|
non_fx_out = model(**data)
|
||||||
|
|
||||||
|
if kwargs_transform:
|
||||||
|
data = kwargs_transform(data)
|
||||||
|
|
||||||
fx_out = gm(**data)
|
fx_out = gm(**data)
|
||||||
if isinstance(fx_out, tuple):
|
if isinstance(fx_out, tuple):
|
||||||
for non_fx, fx in zip(non_fx_out, fx_out):
|
for non_fx, fx in zip(non_fx_out, fx_out):
|
||||||
|
|
Loading…
Reference in New Issue