mirror of https://github.com/hpcaitech/ColossalAI
[tests] model zoo add torchaudio models (#3138)
* [tests] model zoo add torchaudio models
* [tests] refactor torchaudio wavernn
* [tests] refactor fx torchaudio tests

pull/3139/head^2
parent 6d48eb0560
commit 14a115000b

@@ -1,4 +1,4 @@
-from . import diffusers, timm, torchvision, transformers
+from . import diffusers, timm, torchaudio, torchvision, transformers
 from .registry import model_zoo
 
 __all__ = ['model_zoo']
@@ -0,0 +1 @@
from .torchaudio import *
@@ -0,0 +1,130 @@
import torch
import torchaudio.models as tm

from ..registry import ModelAttribute, model_zoo

INPUT_DIM = 80
IN_FEATURES = 16
N_TIME = 20
KERNEL_SIZE = 5
HOP_LENGTH = 20
N_CLASSES = 10
N_FREQ = 16
N_MELS = 80

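# Convention for this file: each data_gen_fn returns a kwargs dict matching the
# model's forward signature, and each output_transform_fn maps the raw model
# output to a flat dict of tensors so outputs can be compared key by key.
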
def conformer_data_gen_fn():
    lengths = torch.randint(1, 400, (4,))
    input = torch.rand(4, int(lengths.max()), INPUT_DIM)
    return dict(input=input, lengths=lengths)


transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1])

model_zoo.register(name='torchaudio_conformer',
                   model_fn=lambda: tm.Conformer(
                       input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31),
                   data_gen_fn=conformer_data_gen_fn,
                   output_transform_fn=transformer_output_transform_fn)

single_output_transform_fn = lambda output: dict(output=output)

model_zoo.register(name='torchaudio_convtasnet',
                   model_fn=tm.ConvTasNet,
                   data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
                   output_transform_fn=single_output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='torchaudio_deepspeech',
                   model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
                   data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
                   output_transform_fn=single_output_transform_fn)


def emformer_data_gen_fn():
    input = torch.rand(4, 400, IN_FEATURES)
    lengths = torch.randint(1, 200, (4,))
    return dict(input=input, lengths=lengths)


model_zoo.register(
    name='torchaudio_emformer',
    model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4),
    data_gen_fn=emformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='torchaudio_wav2letter_waveform',
                   model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40),
                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
                   output_transform_fn=single_output_transform_fn)

model_zoo.register(name='torchaudio_wav2letter_mfcc',
                   model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40),
                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
                   output_transform_fn=single_output_transform_fn)

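# WaveRNN pairs a waveform with its mel spectrogram; torchaudio expects the
# waveform to span (n_time - kernel_size + 1) * hop_length samples for an
# n_time-frame spectrogram, which the shapes below satisfy.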
def wavernn_data_gen_fn():
    waveform = torch.rand(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)
    specgram = torch.rand(4, 1, N_FREQ, N_TIME)
    return dict(waveform=waveform, specgram=specgram)


model_zoo.register(name='torchaudio_wavernn',
                   model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5],
                                               n_classes=N_CLASSES,
                                               hop_length=HOP_LENGTH,
                                               kernel_size=KERNEL_SIZE,
                                               n_freq=N_FREQ,
                                               n_res_block=2,
                                               n_rnn=64,
                                               n_fc=64,
                                               n_hidden=16,
                                               n_output=16),
                   data_gen_fn=wavernn_data_gen_fn,
                   output_transform_fn=single_output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))


def tacotron_data_gen_fn():
    n_batch = 4
    max_text_length = 100
    max_mel_specgram_length = 300
    tokens = torch.randint(0, 148, (n_batch, max_text_length))
    token_lengths = max_text_length * torch.ones((n_batch,))
    mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
    mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
    return dict(tokens=tokens,
                token_lengths=token_lengths,
                mel_specgram=mel_specgram,
                mel_specgram_lengths=mel_specgram_lengths)


model_zoo.register(
    name='torchaudio_tacotron',
    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
    data_gen_fn=tacotron_data_gen_fn,
    output_transform_fn=lambda outputs: dict(
        spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
    model_attribute=ModelAttribute(has_control_flow=True))

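# wav2vec2 and HuBERT share the same (waveforms, lengths) forward signature,
# so a single data_gen_fn serves both registrations below.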
def wav2vec_data_gen_fn():
    batch_size, num_frames = 4, 400
    waveforms = torch.randn(batch_size, num_frames)
    lengths = torch.randint(0, num_frames, (batch_size,))
    return dict(waveforms=waveforms, lengths=lengths)


model_zoo.register(name='torchaudio_wav2vec2_base',
                   model_fn=tm.wav2vec2_base,
                   data_gen_fn=wav2vec_data_gen_fn,
                   output_transform_fn=transformer_output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='torchaudio_hubert_base',
                   model_fn=tm.hubert_base,
                   data_gen_fn=wav2vec_data_gen_fn,
                   output_transform_fn=transformer_output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
@@ -1,145 +0,0 @@
import torch
from torchaudio_utils import trace_and_compare
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
import pytest


def test_wave2letter_waveform():
    batch_size = 2
    num_features = 1
    num_classes = 40
    input_length = 320

    model = Wav2Letter(num_classes=num_classes, num_features=num_features)

    def data_gen():
        x = torch.rand(batch_size, num_features, input_length)
        return dict(x=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_wave2letter_mfcc():
    batch_size = 2
    num_features = 13
    num_classes = 40
    input_length = 2

    model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

    def data_gen():
        x = torch.rand(batch_size, num_features, input_length)
        return dict(x=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_melresnet_waveform():
    n_batch = 2
    n_time = 200
    n_freq = 100
    n_output = 128
    n_res_block = 10
    n_hidden = 128
    kernel_size = 5

    model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

    def data_gen():
        x = torch.rand(n_batch, n_freq, n_time)
        return dict(specgram=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_upsample_network_waveform():
    upsample_scales = [5, 5, 8]
    n_batch = 2
    n_time = 200
    n_freq = 100
    n_output = 64
    n_res_block = 10
    n_hidden = 32
    kernel_size = 5

    total_scale = 1
    for upsample_scale in upsample_scales:
        total_scale *= upsample_scale

    model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)

    def data_gen():
        x = torch.rand(n_batch, n_freq, n_time)
        return dict(specgram=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_wavernn_waveform():
    upsample_scales = [2, 2, 5]
    n_rnn = 16
    n_fc = 16
    n_classes = 10
    hop_length = 20
    n_batch = 2
    n_time = 20
    n_freq = 10
    n_output = 16
    n_res_block = 3
    n_hidden = 16
    kernel_size = 5

    model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden,
                    n_output)

    def data_gen():
        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
        mels = torch.rand(n_batch, 1, n_freq, n_time)
        return dict(waveform=x, specgram=mels)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


def test_convtasnet_config():
    batch_size = 32
    num_frames = 800

    model = ConvTasNet()

    def data_gen():
        tensor = torch.rand(batch_size, 1, num_frames)
        return dict(input=tensor)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


def test_deepspeech():
    n_batch = 2
    n_feature = 1
    n_channel = 1
    n_class = 40
    n_time = 32

    model = DeepSpeech(n_feature=n_feature, n_class=n_class)

    def data_gen():
        x = torch.rand(n_batch, n_channel, n_time, n_feature)
        return dict(x=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


if __name__ == '__main__':
    TEST_LIST = [
        test_wave2letter_waveform,
        test_wave2letter_mfcc,
        test_melresnet_waveform,
        test_upsample_network_waveform,
        test_wavernn_waveform,
        test_convtasnet_config,
        test_deepspeech,
    ]

    for test_fn in TEST_LIST:
        test_fn()
@@ -0,0 +1,22 @@
import re

import torch
from torchaudio_utils import trace_and_compare

from tests.kit.model_zoo import model_zoo


def test_torchaudio_models():
    torch.backends.cudnn.deterministic = True

    sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
        # FIXME(ver217): temporarily skip these models
        if re.search('(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
            continue
        model = model_fn()
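        # entries registered with has_control_flow=True are traced with meta
        # args rather than concrete inputs (see trace_and_compare)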
        trace_and_compare(model,
                          data_gen_fn,
                          output_transform_fn,
                          need_meta=(attribute is not None and attribute.has_control_flow))
@@ -1,57 +0,0 @@
import torch
from torchaudio.models import Tacotron2
from torchaudio_utils import trace_and_compare
import pytest


def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
    return Tacotron2(
        mask_padding=False,
        n_mels=n_mels,
        n_symbol=20,
        n_frames_per_step=1,
        symbol_embedding_dim=32,
        encoder_embedding_dim=32,
        encoder_n_convolution=3,
        encoder_kernel_size=5,
        decoder_rnn_dim=32,
        decoder_max_step=decoder_max_step,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=32,
        attention_hidden_dim=32,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=32,
        postnet_n_convolution=5,
        postnet_kernel_size=5,
        postnet_embedding_dim=512,
        gate_threshold=gate_threshold,
    )


@pytest.mark.skip("Tracing failed")
def test_tacotron_model():
    n_mels = 80
    n_batch = 3
    max_mel_specgram_length = 300
    max_text_length = 100

    model = _get_tacotron2_model(n_mels)

    def data_gen():
        text = torch.randint(0, 148, (n_batch, max_text_length))
        text_lengths = max_text_length * torch.ones((n_batch,))
        mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length)
        mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
        return dict(tokens=text,
                    token_lengths=text_lengths,
                    mel_specgram=mel_specgram,
                    mel_specgram_lengths=mel_specgram_lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


if __name__ == "__main__":
    test_tacotron_model()
@@ -1,67 +0,0 @@
import torch
from torchaudio_utils import trace_and_compare
from torchaudio.models import Emformer, Conformer
import pytest


def test_conformer():
    input_dim = 80
    batch_size = 10
    num_frames = 400
    num_heads = 4
    ffn_dim = 128
    num_layers = 4
    depthwise_conv_kernel_size = 31

    model = Conformer(
        input_dim=input_dim,
        num_heads=num_heads,
        ffn_dim=ffn_dim,
        num_layers=num_layers,
        depthwise_conv_kernel_size=depthwise_conv_kernel_size,
    )

    def data_gen():
        lengths = torch.randint(1, num_frames, (batch_size,))
        input = torch.rand(batch_size, int(lengths.max()), input_dim)
        return dict(input=input, lengths=lengths)

    def kwargs_transform(data):
        new_data = {}

        for k, v in data.items():
            new_data[f'{k}_1'] = v
        return new_data

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)


@pytest.mark.skip("Tracing failed")
def test_emformer():
    input_dim = 128
    batch_size = 10
    num_heads = 8
    ffn_dim = 256
    num_layers = 3
    segment_length = 4
    num_frames = 400
    right_context_length = 1

    model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length)

    def data_gen():
        lengths = torch.randint(1, num_frames, (batch_size,))
        input = torch.rand(batch_size, num_frames, input_dim)
        return dict(input=input, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


@pytest.mark.skip
def test_torchaudio_transformers():
    test_conformer()
    test_emformer()


if __name__ == "__main__":
    test_torchaudio_transformers()
@@ -1,50 +0,0 @@
import torch
from torchaudio.models.wav2vec2 import (
    hubert_base,
    hubert_large,
    hubert_xlarge,
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
)
from torchaudio_utils import trace_and_compare
import pytest

MODEL_LIST = [
    hubert_base,
    hubert_large,
    hubert_xlarge,
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
]


def _smoke_test(model, device):
    model = model.to(device=device)

    batch_size, num_frames = 3, 1024

    def data_gen():
        waveforms = torch.randn(batch_size, num_frames, device=device)
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
            device=device,
        )
        return dict(waveforms=waveforms, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


@pytest.mark.skip("Tracing failed")
def test_wav2vec():
    for model_fn in MODEL_LIST:
        _smoke_test(model_fn(), 'cpu')


if __name__ == "__main__":
    test_wav2vec()
@@ -3,7 +3,7 @@ import torch
 from colossalai.fx import symbolic_trace
 
 
-def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
+def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
@@ -14,16 +14,15 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
 
     with torch.no_grad():
         non_fx_out = model(**data)
 
-    if kwargs_transform:
-        data = kwargs_transform(data)
-
     fx_out = gm(**data)
-    if isinstance(fx_out, tuple):
-        for non_fx, fx in zip(non_fx_out, fx_out):
-            assert torch.allclose(
-                non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-    else:
-        assert torch.allclose(
-            fx_out, non_fx_out,
-            atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+
+    for key, fx_output_val in transformed_fx_out.items():
+        non_fx_output_val = transformed_non_fx_out[key]
+        assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+            f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'