You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/kit/model_zoo/torchaudio/torchaudio.py

152 lines
4.5 KiB

from functools import partial
import torch
import torchaudio.models as tm
from ..registry import ModelAttribute, model_zoo
# Feature sizes consumed by the model constructors and data generators below.
INPUT_DIM = 80  # Conformer input feature size
IN_FEATURES = 16  # DeepSpeech / Emformer input feature size
N_MELS = 80  # Tacotron2 mel-spectrogram channels

# WaveRNN settings shared by its constructor and its data generator.
N_CLASSES = 10
N_FREQ = 16
N_TIME = 20
KERNEL_SIZE = 5
HOP_LENGTH = 20
def conformer_data_gen_fn():
    """Build a random batch of keyword arguments for the Conformer model.

    Returns:
        dict: ``input`` is a random tensor of shape
        ``(4, max(lengths), INPUT_DIM)`` and ``lengths`` holds four random
        integers drawn from ``[1, 400)``.
    """
    lengths = torch.randint(1, 400, (4,))
    # The time dimension equals the longest sampled length. The local is not
    # named ``input`` so the builtin stays unshadowed; the dict key is kept.
    features = torch.rand(4, int(lengths.max()), INPUT_DIM)
    return dict(input=features, lengths=lengths)
def transformer_output_transform_fn(outputs):
    """Label the first two elements of ``outputs`` as ``frames`` and ``lengths``."""
    frames = outputs[0]
    lengths = outputs[1]
    return {"frames": frames, "lengths": lengths}
# Conformer: fed a feature batch plus per-sequence lengths; the first two
# outputs are labelled as frames/lengths by the transform.
model_zoo.register(
    name="torchaudio_conformer",
    data_gen_fn=conformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_fn=lambda: tm.Conformer(
        input_dim=INPUT_DIM,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        depthwise_conv_kernel_size=31,
    ),
)
def single_output_transform_fn(output):
    """Wrap a lone model output in a dict under the ``output`` key."""
    return {"output": output}
# ConvTasNet built with its default constructor arguments; registered as
# containing control flow.
model_zoo.register(
    name="torchaudio_convtasnet",
    model_fn=tm.ConvTasNet,
    model_attribute=ModelAttribute(has_control_flow=True),
    output_transform_fn=single_output_transform_fn,
    data_gen_fn=lambda: {"input": torch.rand(4, 1, 8)},
)
# DeepSpeech: the last input dimension matches the IN_FEATURES passed to the
# constructor.
model_zoo.register(
    name="torchaudio_deepspeech",
    output_transform_fn=single_output_transform_fn,
    data_gen_fn=lambda: {"x": torch.rand(4, 1, 10, IN_FEATURES)},
    model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
)
def emformer_data_gen_fn():
    """Build a random batch of keyword arguments for the Emformer model.

    Returns:
        dict: ``input`` is a random tensor of shape ``(4, 400, IN_FEATURES)``
        and ``lengths`` holds four random integers drawn from ``[1, 200)``.
    """
    # Named ``features`` rather than ``input`` to avoid shadowing the builtin;
    # the ``input`` dict key is kept as-is.
    features = torch.rand(4, 400, IN_FEATURES)
    lengths = torch.randint(1, 200, (4,))
    return dict(input=features, lengths=lengths)
# Emformer: registered as containing control flow; the first two outputs are
# labelled as frames/lengths by the transform.
model_zoo.register(
    name="torchaudio_emformer",
    data_gen_fn=emformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
    model_fn=lambda: tm.Emformer(
        input_dim=IN_FEATURES,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        segment_length=4,
    ),
)
# Wav2Letter configured for waveform input; the input's second dimension
# matches num_features.
model_zoo.register(
    name="torchaudio_wav2letter_waveform",
    output_transform_fn=single_output_transform_fn,
    data_gen_fn=lambda: {"x": torch.rand(4, 40, 400)},
    model_fn=lambda: tm.Wav2Letter(num_features=40, input_type="waveform"),
)
# Wav2Letter configured for MFCC input; the input's second dimension matches
# num_features.
model_zoo.register(
    name="torchaudio_wav2letter_mfcc",
    output_transform_fn=single_output_transform_fn,
    data_gen_fn=lambda: {"x": torch.rand(4, 40, 400)},
    model_fn=lambda: tm.Wav2Letter(num_features=40, input_type="mfcc"),
)
def wavernn_data_gen_fn():
    """Build a random batch of keyword arguments for the WaveRNN model.

    Returns:
        dict: ``waveform`` is a random tensor of shape
        ``(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)`` and ``specgram``
        is a random tensor of shape ``(4, 1, N_FREQ, N_TIME)``.
    """
    # The waveform length is derived from the spectrogram frame count, the
    # kernel size and the hop length, so the two inputs stay consistent.
    n_samples = (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH
    waveform = torch.rand(4, 1, n_samples)
    specgram = torch.rand(4, 1, N_FREQ, N_TIME)
    return dict(waveform=waveform, specgram=specgram)
# WaveRNN: registered as containing control flow. The upsample scales multiply
# to 2 * 2 * 5 == 20, the same value as HOP_LENGTH.
model_zoo.register(
    name="torchaudio_wavernn",
    data_gen_fn=wavernn_data_gen_fn,
    output_transform_fn=single_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
    model_fn=lambda: tm.WaveRNN(
        # Signal-shape settings shared with wavernn_data_gen_fn.
        upsample_scales=[2, 2, 5],
        hop_length=HOP_LENGTH,
        kernel_size=KERNEL_SIZE,
        n_freq=N_FREQ,
        n_classes=N_CLASSES,
        # Small layer sizes for the network itself.
        n_res_block=2,
        n_rnn=64,
        n_fc=64,
        n_hidden=16,
        n_output=16,
    ),
)
def tacotron_data_gen_fn():
    """Build a random batch of keyword arguments for the Tacotron2 model.

    Returns:
        dict: ``tokens`` holds random integers in ``[0, 148)`` with shape
        ``(4, 100)``; ``mel_specgram`` is a random tensor of shape
        ``(4, N_MELS, 300)``; ``token_lengths`` and ``mel_specgram_lengths``
        are length-4 tensors in which every entry equals the corresponding
        maximum length.
    """
    n_batch = 4
    max_text_length = 100
    max_mel_specgram_length = 300

    tokens = torch.randint(0, 148, (n_batch, max_text_length))
    # Every sample is reported at full length. The tensors come from
    # torch.ones, so they use torch's default dtype.
    token_lengths = max_text_length * torch.ones((n_batch,))
    mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
    mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
    return dict(
        tokens=tokens,
        token_lengths=token_lengths,
        mel_specgram=mel_specgram,
        mel_specgram_lengths=mel_specgram_lengths,
    )
# Tacotron2: registered as containing control flow. All outputs are reduced
# to one scalar by summing each output's own sum.
model_zoo.register(
    name="torchaudio_tacotron",
    data_gen_fn=tacotron_data_gen_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
    output_transform_fn=lambda outputs: {
        "summed_output": sum(tensor.sum() for tensor in outputs)
    },
)
def wav2vec_data_gen_fn():
    """Build a random batch of keyword arguments for wav2vec2-style models.

    Returns:
        dict: ``waveforms`` is a normally distributed tensor of shape
        ``(4, 400)`` and ``lengths`` holds four random integers drawn from
        ``[0, 400)``.
    """
    batch_size, num_frames = 4, 400
    waveforms = torch.randn(batch_size, num_frames)
    # torch.randint excludes the upper bound, so no length reaches num_frames.
    lengths = torch.randint(0, num_frames, (batch_size,))
    return dict(waveforms=waveforms, lengths=lengths)
# wav2vec2 base, built with encoder_layer_drop forced to 0.0; registered as
# containing control flow.
model_zoo.register(
    name="torchaudio_wav2vec2_base",
    data_gen_fn=wav2vec_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
    model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
)
# HuBERT base with its default constructor arguments; uses the same data
# generator as the wav2vec2 entry and is registered as containing control flow.
model_zoo.register(
    name="torchaudio_hubert_base",
    data_gen_fn=wav2vec_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
    model_fn=tm.hubert_base,
)