"""Register torchaudio models with the shared model zoo.

Each ``model_zoo.register`` call pairs a model factory (``model_fn``) with a
data generator (``data_gen_fn``) and an output transform
(``output_transform_fn``). Data generators return a dict of named tensors;
the keys presumably map onto the model's forward keyword arguments — see
``..registry`` to confirm how they are consumed.
"""
from functools import partial

import torch
import torchaudio.models as tm

from ..registry import ModelAttribute, model_zoo

# Sizes shared by the data generators and model factories below.
INPUT_DIM = 80
IN_FEATURES = 16
N_TIME = 20
KERNEL_SIZE = 5
HOP_LENGTH = 20
N_CLASSES = 10
N_FREQ = 16
N_MELS = 80


def transformer_output_transform_fn(outputs):
    """Name the first two elements of a tuple-like model output ``frames`` and ``lengths``."""
    return dict(frames=outputs[0], lengths=outputs[1])


def single_output_transform_fn(output):
    """Wrap a single model output in a dict under the key ``output``."""
    return dict(output=output)


def conformer_data_gen_fn():
    """Return Conformer inputs for a batch of 4.

    ``lengths`` holds 4 integers drawn from [1, 399]; ``input`` is random data of
    shape ``(4, max(lengths), INPUT_DIM)``.
    """
    # Draw lengths first: the padded time dimension depends on the longest one.
    lengths = torch.randint(1, 400, (4,))
    features = torch.rand(4, int(lengths.max()), INPUT_DIM)
    return dict(input=features, lengths=lengths)


model_zoo.register(
    name="torchaudio_conformer",
    model_fn=lambda: tm.Conformer(
        input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31
    ),
    data_gen_fn=conformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
)

model_zoo.register(
    name="torchaudio_convtasnet",
    model_fn=tm.ConvTasNet,
    data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
    output_transform_fn=single_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(
    name="torchaudio_deepspeech",
    model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
    data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
    output_transform_fn=single_output_transform_fn,
)


def emformer_data_gen_fn():
    """Return Emformer inputs: 4 random 400-frame sequences and lengths drawn from [1, 199]."""
    features = torch.rand(4, 400, IN_FEATURES)
    lengths = torch.randint(1, 200, (4,))
    return dict(input=features, lengths=lengths)


model_zoo.register(
    name="torchaudio_emformer",
    model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4),
    data_gen_fn=emformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(
    name="torchaudio_wav2letter_waveform",
    model_fn=lambda: tm.Wav2Letter(input_type="waveform", num_features=40),
    data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
    output_transform_fn=single_output_transform_fn,
)

model_zoo.register(
    name="torchaudio_wav2letter_mfcc",
    model_fn=lambda: tm.Wav2Letter(input_type="mfcc", num_features=40),
    data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
    output_transform_fn=single_output_transform_fn,
)


def wavernn_data_gen_fn():
    """Return WaveRNN inputs whose waveform length is consistent with the spectrogram.

    ``specgram`` has shape ``(4, 1, N_FREQ, N_TIME)``; ``waveform`` has shape
    ``(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)``.
    """
    waveform = torch.rand(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)
    specgram = torch.rand(4, 1, N_FREQ, N_TIME)
    return dict(waveform=waveform, specgram=specgram)


model_zoo.register(
    name="torchaudio_wavernn",
    model_fn=lambda: tm.WaveRNN(
        # 2 * 2 * 5 == HOP_LENGTH: WaveRNN requires the product of the
        # upsample scales to equal hop_length.
        upsample_scales=[2, 2, 5],
        n_classes=N_CLASSES,
        hop_length=HOP_LENGTH,
        kernel_size=KERNEL_SIZE,
        n_freq=N_FREQ,
        n_res_block=2,
        n_rnn=64,
        n_fc=64,
        n_hidden=16,
        n_output=16,
    ),
    data_gen_fn=wavernn_data_gen_fn,
    output_transform_fn=single_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)


def tacotron_data_gen_fn():
    """Return Tacotron2 inputs in which every sample is at its full (unpadded) length.

    Token ids are drawn from [0, 147]. Length tensors are integer counts,
    matching the integer length tensors produced elsewhere in this file.
    """
    n_batch = 4
    max_text_length = 100
    max_mel_specgram_length = 300
    # 148 matches Tacotron2's default symbol-table size (n_symbol).
    tokens = torch.randint(0, 148, (n_batch, max_text_length))
    # Integer lengths (previously float32 via `N * torch.ones(...)`).
    token_lengths = torch.full((n_batch,), max_text_length, dtype=torch.int64)
    mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
    mel_specgram_lengths = torch.full((n_batch,), max_mel_specgram_length, dtype=torch.int64)
    return dict(
        tokens=tokens, token_lengths=token_lengths, mel_specgram=mel_specgram, mel_specgram_lengths=mel_specgram_lengths
    )


model_zoo.register(
    name="torchaudio_tacotron",
    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
    data_gen_fn=tacotron_data_gen_fn,
    # Tacotron2 returns several tensors; collapse them into one scalar.
    output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
    model_attribute=ModelAttribute(has_control_flow=True),
)


def wav2vec_data_gen_fn():
    """Return wav2vec2-style inputs: 4 random 400-sample waveforms and per-sample lengths.

    Lengths are drawn from [1, 399] so that, as in the Conformer and Emformer
    generators, no sample is made up entirely of padding.
    """
    batch_size, num_frames = 4, 400
    waveforms = torch.randn(batch_size, num_frames)
    lengths = torch.randint(1, num_frames, (batch_size,))
    return dict(waveforms=waveforms, lengths=lengths)


model_zoo.register(
    name="torchaudio_wav2vec2_base",
    # encoder_layer_drop=0.0 disables LayerDrop (randomly skipped encoder
    # layers), keeping forward passes deterministic in training mode.
    model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
    data_gen_fn=wav2vec_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
# HuBERT reuses the wav2vec2 data generator and output transform. LayerDrop is
# disabled here, matching the wav2vec2_base registration: torchaudio's
# hubert_base defaults to a non-zero encoder_layer_drop. That default randomly
# skips encoder layers in training mode, making outputs non-deterministic
# across runs.
model_zoo.register(
    name="torchaudio_hubert_base",
    model_fn=partial(tm.hubert_base, encoder_layer_drop=0.0),
    data_gen_fn=wav2vec_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)