From 14a115000b29ffe0680e3241d4dbb045389eb56e Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 15 Mar 2023 11:51:16 +0800
Subject: [PATCH] [tests] model zoo add torchaudio models (#3138)

* [tests] model zoo add torchaudio models

* [tests] refactor torchaudio wavernn

* [tests] refactor fx torchaudio tests
---
 tests/kit/model_zoo/__init__.py               |   2 +-
 tests/kit/model_zoo/torchaudio/__init__.py    |   1 +
 tests/kit/model_zoo/torchaudio/torchaudio.py  | 130 ++++++++++++++++
 .../test_torchaudio_general.py                | 145 ------------------
 .../test_torchaudio_model.py                  |  22 +++
 .../test_torchaudio_tacotron.py               |  57 -------
 .../test_torchaudio_transformer.py            |  67 --------
 .../test_torchaudio_wave2vec.py               |  50 ------
 .../test_torchaudio_model/torchaudio_utils.py |  25 ++-
 9 files changed, 166 insertions(+), 333 deletions(-)
 create mode 100644 tests/kit/model_zoo/torchaudio/__init__.py
 create mode 100644 tests/kit/model_zoo/torchaudio/torchaudio.py
 delete mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py
 create mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
 delete mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
 delete mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
 delete mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py

diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py
index 7f14d04c0..82a61626b 100644
--- a/tests/kit/model_zoo/__init__.py
+++ b/tests/kit/model_zoo/__init__.py
@@ -1,4 +1,4 @@
-from . import diffusers, timm, torchvision, transformers
+from . import diffusers, timm, torchaudio, torchvision, transformers
 from .registry import model_zoo
 
 __all__ = ['model_zoo']
diff --git a/tests/kit/model_zoo/torchaudio/__init__.py b/tests/kit/model_zoo/torchaudio/__init__.py
new file mode 100644
index 000000000..082eb9ebb
--- /dev/null
+++ b/tests/kit/model_zoo/torchaudio/__init__.py
@@ -0,0 +1 @@
+from .torchaudio import *
diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py
new file mode 100644
index 000000000..746117202
--- /dev/null
+++ b/tests/kit/model_zoo/torchaudio/torchaudio.py
@@ -0,0 +1,130 @@
+import torch
+import torchaudio.models as tm
+
+from ..registry import ModelAttribute, model_zoo
+
+INPUT_DIM = 80
+IN_FEATURES = 16
+N_TIME = 20
+KERNEL_SIZE = 5
+HOP_LENGTH = 20
+N_CLASSES = 10
+N_FREQ = 16
+N_MELS = 80
+
+
+def conformer_data_gen_fn():
+    lengths = torch.randint(1, 400, (4,))
+    input = torch.rand(4, int(lengths.max()), INPUT_DIM)
+    return dict(input=input, lengths=lengths)
+
+
+transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1])
+
+model_zoo.register(name='torchaudio_conformer',
+                   model_fn=lambda: tm.Conformer(
+                       input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31),
+                   data_gen_fn=conformer_data_gen_fn,
+                   output_transform_fn=transformer_output_transform_fn)
+
+single_output_transform_fn = lambda output: dict(output=output)
+
+model_zoo.register(name='torchaudio_convtasnet',
+                   model_fn=tm.ConvTasNet,
+                   data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
+                   output_transform_fn=single_output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='torchaudio_deepspeech',
+                   model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
+                   data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
+                   output_transform_fn=single_output_transform_fn)
+
+
+def emformer_data_gen_fn():
+    input = torch.rand(4, 400, IN_FEATURES)
+    lengths = torch.randint(1, 200, (4,))
+    return dict(input=input, lengths=lengths)
+
+
+model_zoo.register(
+    name='torchaudio_emformer',
+    model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4),
+    data_gen_fn=emformer_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='torchaudio_wav2letter_waveform',
+                   model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40),
+                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+                   output_transform_fn=single_output_transform_fn)
+
+model_zoo.register(name='torchaudio_wav2letter_mfcc',
+                   model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40),
+                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+                   output_transform_fn=single_output_transform_fn)
+
+
+def wavernn_data_gen_fn():
+    waveform = torch.rand(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)
+    specgram = torch.rand(4, 1, N_FREQ, N_TIME)
+    return dict(waveform=waveform, specgram=specgram)
+
+
+model_zoo.register(name='torchaudio_wavernn',
+                   model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5],
+                                               n_classes=N_CLASSES,
+                                               hop_length=HOP_LENGTH,
+                                               kernel_size=KERNEL_SIZE,
+                                               n_freq=N_FREQ,
+                                               n_res_block=2,
+                                               n_rnn=64,
+                                               n_fc=64,
+                                               n_hidden=16,
+                                               n_output=16),
+                   data_gen_fn=wavernn_data_gen_fn,
+                   output_transform_fn=single_output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+
+def tacotron_data_gen_fn():
+    n_batch = 4
+    max_text_length = 100
+    max_mel_specgram_length = 300
+    tokens = torch.randint(0, 148, (n_batch, max_text_length))
+    token_lengths = max_text_length * torch.ones((n_batch,))
+    mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
+    mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
+    return dict(tokens=tokens,
+                token_lengths=token_lengths,
+                mel_specgram=mel_specgram,
+                mel_specgram_lengths=mel_specgram_lengths)
+
+
+model_zoo.register(
+    name='torchaudio_tacotron',
+    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
+    data_gen_fn=tacotron_data_gen_fn,
+    output_transform_fn=lambda outputs: dict(
+        spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
+    model_attribute=ModelAttribute(has_control_flow=True))
+
+
+def wav2vec_data_gen_fn():
+    batch_size, num_frames = 4, 400
+    waveforms = torch.randn(batch_size, num_frames)
+    lengths = torch.randint(0, num_frames, (batch_size,))
+    return dict(waveforms=waveforms, lengths=lengths)
+
+
+model_zoo.register(name='torchaudio_wav2vec2_base',
+                   model_fn=tm.wav2vec2_base,
+                   data_gen_fn=wav2vec_data_gen_fn,
+                   output_transform_fn=transformer_output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='torchaudio_hubert_base',
+                   model_fn=tm.hubert_base,
+                   data_gen_fn=wav2vec_data_gen_fn,
+                   output_transform_fn=transformer_output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py
deleted file mode 100644
index b2fa8c6c0..000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import torch
-from torchaudio_utils import trace_and_compare
-from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
-from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
-import pytest
-
-
-def test_wave2letter_waveform():
-    batch_size = 2
-    num_features = 1
-    num_classes = 40
-    input_length = 320
-
-    model = Wav2Letter(num_classes=num_classes, num_features=num_features)
-
-    def data_gen():
-        x = torch.rand(batch_size, num_features, input_length)
-        return dict(x=x)
-
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_wave2letter_mfcc():
-    batch_size = 2
-    num_features = 13
-    num_classes = 40
-    input_length = 2
-
-    model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)
-
-    def data_gen():
-        x = torch.rand(batch_size, num_features, input_length)
-        return dict(x=x)
-
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_melresnet_waveform():
-    n_batch = 2
-    n_time = 200
-    n_freq = 100
-    n_output = 128
-    n_res_block = 10
-    n_hidden = 128
-    kernel_size = 5
-
-    model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
-
-    def data_gen():
-        x = torch.rand(n_batch, n_freq, n_time)
-        return dict(specgram=x)
-
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_upsample_network_waveform():
-    upsample_scales = [5, 5, 8]
-    n_batch = 2
-    n_time = 200
-    n_freq = 100
-    n_output = 64
-    n_res_block = 10
-    n_hidden = 32
-    kernel_size = 5
-
-    total_scale = 1
-    for upsample_scale in upsample_scales:
-        total_scale *= upsample_scale
-
-    model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
-
-    def data_gen():
-        x = torch.rand(n_batch, n_freq, n_time)
-        return dict(specgram=x)
-
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_wavernn_waveform():
-    upsample_scales = [2, 2, 5]
-    n_rnn = 16
-    n_fc = 16
-    n_classes = 10
-    hop_length = 20
-    n_batch = 2
-    n_time = 20
-    n_freq = 10
-    n_output = 16
-    n_res_block = 3
-    n_hidden = 16
-    kernel_size = 5
-
-    model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden,
-                    n_output)
-
-    def data_gen():
-        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
-        mels = torch.rand(n_batch, 1, n_freq, n_time)
-        return dict(waveform=x, specgram=mels)
-
-    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-def test_convtasnet_config():
-    batch_size = 32
-    num_frames = 800
-
-    model = ConvTasNet()
-
-    def data_gen():
-        tensor = torch.rand(batch_size, 1, num_frames)
-        return dict(input=tensor)
-
-    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-def test_deepspeech():
-    n_batch = 2
-    n_feature = 1
-    n_channel = 1
-    n_class = 40
-    n_time = 32
-
-    model = DeepSpeech(n_feature=n_feature, n_class=n_class)
-
-    def data_gen():
-        x = torch.rand(n_batch, n_channel, n_time, n_feature)
-        return dict(x=x)
-
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-if __name__ == '__main__':
-    TEST_LIST = [
-        test_wave2letter_waveform,
-        test_wave2letter_mfcc,
-        test_melresnet_waveform,
-        test_upsample_network_waveform,
-        test_wavernn_waveform,
-        test_convtasnet_config,
-        test_deepspeech,
-    ]
-
-    for test_fn in TEST_LIST:
-        test_fn()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
new file mode 100644
index 000000000..bf6c7ae55
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
@@ -0,0 +1,22 @@
+import re
+
+import torch
+from torchaudio_utils import trace_and_compare
+
+from tests.kit.model_zoo import model_zoo
+
+
+def test_torchaudio_models():
+    torch.backends.cudnn.deterministic = True
+
+    sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+        # FIXME(ver217): temporarily skip these models
+        if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
+            continue
+        model = model_fn()
+        trace_and_compare(model,
+                          data_gen_fn,
+                          output_transform_fn,
+                          need_meta=(attribute is not None and attribute.has_control_flow))
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
deleted file mode 100644
index 2073c4689..000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import torch
-from torchaudio.models import Tacotron2
-from torchaudio_utils import trace_and_compare
-import pytest
-
-
-def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
-    return Tacotron2(
-        mask_padding=False,
-        n_mels=n_mels,
-        n_symbol=20,
-        n_frames_per_step=1,
-        symbol_embedding_dim=32,
-        encoder_embedding_dim=32,
-        encoder_n_convolution=3,
-        encoder_kernel_size=5,
-        decoder_rnn_dim=32,
-        decoder_max_step=decoder_max_step,
-        decoder_dropout=0.1,
-        decoder_early_stopping=True,
-        attention_rnn_dim=32,
-        attention_hidden_dim=32,
-        attention_location_n_filter=32,
-        attention_location_kernel_size=31,
-        attention_dropout=0.1,
-        prenet_dim=32,
-        postnet_n_convolution=5,
-        postnet_kernel_size=5,
-        postnet_embedding_dim=512,
-        gate_threshold=gate_threshold,
-    )
-
-
-@pytest.mark.skip("Tracing failed")
-def test_tacotron_model():
-    n_mels = 80
-    n_batch = 3
-    max_mel_specgram_length = 300
-    max_text_length = 100
-
-    model = _get_tacotron2_model(n_mels)
-
-    def data_gen():
-        text = torch.randint(0, 148, (n_batch, max_text_length))
-        text_lengths = max_text_length * torch.ones((n_batch,))
-        mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length)
-        mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
-        return dict(tokens=text,
-                    token_lengths=text_lengths,
-                    mel_specgram=mel_specgram,
-                    mel_specgram_lengths=mel_specgram_lengths)
-
-    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-if __name__ == "__main__":
-    test_tacotron_model()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
deleted file mode 100644
index fbe24a8cd..000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import torch
-from torchaudio_utils import trace_and_compare
-from torchaudio.models import Emformer, Conformer
-import pytest
-
-
-def test_conformer():
-    input_dim = 80
-    batch_size = 10
-    num_frames = 400
-    num_heads = 4
-    ffn_dim = 128
-    num_layers = 4
-    depthwise_conv_kernel_size = 31
-
-    model = Conformer(
-        input_dim=input_dim,
-        num_heads=num_heads,
-        ffn_dim=ffn_dim,
-        num_layers=num_layers,
-        depthwise_conv_kernel_size=depthwise_conv_kernel_size,
-    )
-
-    def data_gen():
-        lengths = torch.randint(1, num_frames, (batch_size,))
-        input = torch.rand(batch_size, int(lengths.max()), input_dim)
-        return dict(input=input, lengths=lengths)
-
-    def kwargs_transform(data):
-        new_data = {}
-
-        for k, v in data.items():
-            new_data[f'{k}_1'] = v
-        return new_data
-
-    trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)
-
-
-@pytest.mark.skip("Tracing failed")
-def test_emformer():
-    input_dim = 128
-    batch_size = 10
-    num_heads = 8
-    ffn_dim = 256
-    num_layers = 3
-    segment_length = 4
-    num_frames = 400
-    right_context_length = 1
-
-    model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length)
-
-    def data_gen():
-        lengths = torch.randint(1, num_frames, (batch_size,))
-        input = torch.rand(batch_size, num_frames, input_dim)
-        return dict(input=input, lengths=lengths)
-
-    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-@pytest.mark.skip
-def test_torchaudio_transformers():
-    test_conformer()
-    test_emformer()
-
-
-if __name__ == "__main__":
-    test_torchaudio_transformers()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
deleted file mode 100644
index e8729b83f..000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-from torchaudio.models.wav2vec2 import (
-    hubert_base,
-    hubert_large,
-    hubert_xlarge,
-    wav2vec2_base,
-    wav2vec2_large,
-    wav2vec2_large_lv60k,
-)
-from torchaudio_utils import trace_and_compare
-import pytest
-
-MODEL_LIST = [
-    hubert_base,
-    hubert_large,
-    hubert_xlarge,
-    wav2vec2_base,
-    wav2vec2_large,
-    wav2vec2_large_lv60k,
-]
-
-
-def _smoke_test(model, device):
-    model = model.to(device=device)
-
-    batch_size, num_frames = 3, 1024
-
-    def data_gen():
-        waveforms = torch.randn(batch_size, num_frames, device=device)
-        lengths = torch.randint(
-            low=0,
-            high=num_frames,
-            size=[
-                batch_size,
-            ],
-            device=device,
-        )
-        return dict(waveforms=waveforms, lengths=lengths)
-
-    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-@pytest.mark.skip("Tracing failed")
-def test_wav2vec():
-    for model_fn in MODEL_LIST:
-        _smoke_test(model_fn(), 'cpu')
-
-
-if __name__ == "__main__":
-    test_wav2vec()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
index 702c5f8f6..18d86fc05 100644
--- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
+++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
@@ -3,7 +3,7 @@ import torch
 from colossalai.fx import symbolic_trace
 
 
-def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
+def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
@@ -14,16 +14,15 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwa
 
     with torch.no_grad():
         non_fx_out = model(**data)
-
-    if kwargs_transform:
-        data = kwargs_transform(data)
-
     fx_out = gm(**data)
-    if isinstance(fx_out, tuple):
-        for non_fx, fx in zip(non_fx_out, fx_out):
-            assert torch.allclose(
-                non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-    else:
-        assert torch.allclose(
-            fx_out, non_fx_out,
-            atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+
+    for key, fx_output_val in transformed_fx_out.items():
+        non_fx_output_val = transformed_non_fx_out[key]
+        assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+            f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
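
For reviewers unfamiliar with the model-zoo kit, the registrations above are consumed exactly the way the new `test_torchaudio_models` plus the refactored `trace_and_compare` consume them. Below is a minimal, self-contained sketch of that flow. It is not part of the patch: it assumes the repository root is on `sys.path` (as in the repo's pytest setup) and that `colossalai.fx.symbolic_trace` accepts the `meta_args` mapping built in the `torchaudio_utils.py` hunk above.

```python
import torch
from colossalai.fx import symbolic_trace

from tests.kit.model_zoo import model_zoo

# Each registry value unpacks to (model_fn, data_gen_fn, output_transform_fn, attribute),
# mirroring the loop in test_torchaudio_model.py.
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in \
        model_zoo.get_sub_registry('torchaudio').items():
    model = model_fn().eval()    # eval mode keeps dropout from desyncing the two runs
    data = data_gen_fn()         # kwargs dict, e.g. dict(input=..., lengths=...)

    # Entries registered with has_control_flow=True are traced with meta tensors,
    # which carry shape/dtype but no storage, so data-dependent branches can resolve.
    need_meta = attribute is not None and attribute.has_control_flow
    meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
    gm = symbolic_trace(model, meta_args=meta_args)

    with torch.no_grad():
        fx_out, non_fx_out = gm(**data), model(**data)

    # output_transform_fn normalizes tensor/tuple outputs into a dict, so the
    # traced and eager outputs can be compared key by key.
    for key, fx_val in output_transform_fn(fx_out).items():
        assert torch.allclose(fx_val, output_transform_fn(non_fx_out)[key], atol=1e-5), key
    print(f'{name}: OK')
```

Note that, per the FIXME in the new test file, a few entries (conformer, emformer, tacotron, wav2vec2_base, hubert_base) do not trace yet and would fail in this sketch; the test skips them for now.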
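
One registration detail that is easy to miss: `wavernn_data_gen_fn` couples the waveform length to the spectrogram size. torchaudio's `WaveRNN.forward` expects a waveform of length `hop_length * (n_time - kernel_size + 1)`, because the MelResNet stack applies a width-`kernel_size` valid convolution over the time axis before upsampling by `hop_length` (here the product of `upsample_scales = [2, 2, 5]`). A quick check with the constants from `torchaudio.py` above:

```python
# Constants mirror tests/kit/model_zoo/torchaudio/torchaudio.py.
N_TIME, KERNEL_SIZE, HOP_LENGTH = 20, 5, 20

# The valid conv shortens 20 spectrogram frames to 16; upsampling by 20 then
# requires 320 waveform samples, which is what wavernn_data_gen_fn generates.
assert (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH == 320
```

The deleted `test_wavernn_waveform` encoded the same constraint as `hop_length * (n_time - kernel_size + 1)`; the refactor just moves it into the shared data generator.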