mirror of https://github.com/hpcaitech/ColossalAI
131 lines
5.3 KiB
Python
131 lines
5.3 KiB
Python
from functools import partial
|
|
|
|
import torch
|
|
import torchaudio.models as tm
|
|
|
|
from ..registry import ModelAttribute, model_zoo
|
|
|
|
# Shape / size constants shared by the torchaudio model-zoo entries in this module.

# Conformer: last dimension of the generated input frames.
INPUT_DIM: int = 80

# DeepSpeech and Emformer: last dimension of the generated input frames.
IN_FEATURES: int = 16

# WaveRNN: spectrogram/waveform geometry and output classes.
N_TIME: int = 20
KERNEL_SIZE: int = 5
HOP_LENGTH: int = 20
N_CLASSES: int = 10
N_FREQ: int = 16

# Tacotron2: mel-bin count shared by the model and its generated spectrogram.
N_MELS: int = 80
|
|
|
|
|
|
def conformer_data_gen_fn():
    """Build a random padded batch for the Conformer entry.

    Four sequences receive random valid lengths in ``[1, 399]``. The padded
    time dimension of ``input`` equals the longest of those lengths.

    Returns:
        dict with ``input`` of shape ``(4, max(lengths), INPUT_DIM)`` (uniform
        random floats) and ``lengths`` of shape ``(4,)``.
    """
    batch_size = 4
    seq_lengths = torch.randint(1, 400, (batch_size,))
    padded_time = int(seq_lengths.max())
    features = torch.rand(batch_size, padded_time, INPUT_DIM)
    return {'input': features, 'lengths': seq_lengths}
|
|
|
|
|
|
def transformer_output_transform_fn(outputs):
    """Name the first two elements of a sequence model's output tuple.

    Element 0 is exposed as ``frames`` and element 1 as ``lengths``. Used by
    the Conformer, Emformer, wav2vec2 and HuBERT entries.
    """
    frames = outputs[0]
    lengths = outputs[1]
    return {'frames': frames, 'lengths': lengths}


# Conformer over INPUT_DIM-wide frames; inputs come from conformer_data_gen_fn.
model_zoo.register(
    name='torchaudio_conformer',
    model_fn=lambda: tm.Conformer(input_dim=INPUT_DIM,
                                  num_heads=4,
                                  ffn_dim=128,
                                  num_layers=4,
                                  depthwise_conv_kernel_size=31),
    data_gen_fn=conformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
)
|
|
|
|
def single_output_transform_fn(output):
    """Wrap a model's single return value in a dict under the ``output`` key."""
    return {'output': output}


# ConvTasNet: the class itself is the factory, so it is built with its default
# constructor arguments. Input is a random tensor of shape (4, 1, 8).
model_zoo.register(
    name='torchaudio_convtasnet',
    model_fn=tm.ConvTasNet,
    data_gen_fn=lambda: {'input': torch.rand(4, 1, 8)},
    output_transform_fn=single_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# DeepSpeech over IN_FEATURES-wide features; input ``x`` has shape
# (4, 1, 10, IN_FEATURES).
model_zoo.register(
    name='torchaudio_deepspeech',
    model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
    data_gen_fn=lambda: {'x': torch.rand(4, 1, 10, IN_FEATURES)},
    output_transform_fn=single_output_transform_fn,
)
|
|
|
|
|
|
def emformer_data_gen_fn():
    """Build a random batch for the Emformer entry.

    ``input`` is always padded to 400 time steps. Each of the four ``lengths``
    is drawn from ``[1, 199]``, so it is always shorter than the padding.

    Returns:
        dict with ``input`` of shape ``(4, 400, IN_FEATURES)`` and
        ``lengths`` of shape ``(4,)``.
    """
    batch, padded_time = 4, 400
    frames = torch.rand(batch, padded_time, IN_FEATURES)
    valid_lengths = torch.randint(1, 200, (batch,))
    return {'input': frames, 'lengths': valid_lengths}
|
|
|
|
|
|
# Emformer over IN_FEATURES-wide frames with ``segment_length=4``; inputs come
# from emformer_data_gen_fn, outputs are named by transformer_output_transform_fn.
model_zoo.register(
    name='torchaudio_emformer',
    model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES,
                                 num_heads=4,
                                 ffn_dim=128,
                                 num_layers=4,
                                 segment_length=4),
    data_gen_fn=emformer_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# Wav2Letter is registered once per ``input_type``. Both variants use
# ``num_features=40`` and a random ``x`` of shape (4, 40, 400).
model_zoo.register(
    name='torchaudio_wav2letter_waveform',
    model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40),
    data_gen_fn=lambda: {'x': torch.rand(4, 40, 400)},
    output_transform_fn=single_output_transform_fn,
)
model_zoo.register(
    name='torchaudio_wav2letter_mfcc',
    model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40),
    data_gen_fn=lambda: {'x': torch.rand(4, 40, 400)},
    output_transform_fn=single_output_transform_fn,
)
|
|
|
|
|
|
def wavernn_data_gen_fn():
    """Build a random waveform/spectrogram pair for the WaveRNN entry.

    The waveform length is derived from the spectrogram geometry as
    ``(N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH``. This is presumably the
    relation WaveRNN expects between its two inputs (see torchaudio docs);
    re-check it when changing any of these constants.

    Returns:
        dict with ``waveform`` of shape
        ``(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)`` and ``specgram``
        of shape ``(4, 1, N_FREQ, N_TIME)``, both uniform random floats.
    """
    batch = 4
    num_samples = (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH
    waveform = torch.rand(batch, 1, num_samples)
    specgram = torch.rand(batch, 1, N_FREQ, N_TIME)
    return {'waveform': waveform, 'specgram': specgram}
|
|
|
|
|
|
# WaveRNN sized to the tensors produced by wavernn_data_gen_fn. The upsample
# scales multiply to 2 * 2 * 5 = 20, the current value of HOP_LENGTH.
# NOTE(review): WaveRNN presumably requires prod(upsample_scales) == hop_length;
# keep the two in sync when changing either.
model_zoo.register(
    name='torchaudio_wavernn',
    model_fn=lambda: tm.WaveRNN(
        upsample_scales=[2, 2, 5],
        n_classes=N_CLASSES,
        hop_length=HOP_LENGTH,
        kernel_size=KERNEL_SIZE,
        n_freq=N_FREQ,
        n_res_block=2,
        n_rnn=64,
        n_fc=64,
        n_hidden=16,
        n_output=16,
    ),
    data_gen_fn=wavernn_data_gen_fn,
    output_transform_fn=single_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
|
|
|
|
|
|
def tacotron_data_gen_fn():
    """Build a random, fully populated batch for the Tacotron2 entry.

    Every sample spans its whole padded length, so each entry of
    ``token_lengths`` equals ``max_text_length`` and each entry of
    ``mel_specgram_lengths`` equals ``max_mel_specgram_length``.

    Returns:
        dict with
          * ``tokens``: int64 tensor ``(n_batch, max_text_length)``, ids in ``[0, 148)``;
          * ``token_lengths``: int64 tensor ``(n_batch,)``;
          * ``mel_specgram``: float tensor ``(n_batch, N_MELS, max_mel_specgram_length)``;
          * ``mel_specgram_lengths``: int64 tensor ``(n_batch,)``.
    """
    n_batch = 4
    max_text_length = 100
    max_mel_specgram_length = 300
    # 148 should match Tacotron2's default ``n_symbol`` vocabulary size (the
    # model is registered with default arguments); token ids must stay below it.
    tokens = torch.randint(0, 148, (n_batch, max_text_length))
    # Lengths are counts, so build them as integer tensors. The previous
    # ``max_len * torch.ones(...)`` form silently produced float32 lengths.
    token_lengths = torch.full((n_batch,), max_text_length, dtype=torch.long)
    mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
    mel_specgram_lengths = torch.full((n_batch,), max_mel_specgram_length, dtype=torch.long)
    return dict(tokens=tokens,
                token_lengths=token_lengths,
                mel_specgram=mel_specgram,
                mel_specgram_lengths=mel_specgram_lengths)
|
|
|
|
|
|
def _tacotron_output_transform_fn(outputs):
    """Collapse every tensor in ``outputs`` into one scalar ``summed_output``."""
    total = sum(tensor.sum() for tensor in outputs)
    return {'summed_output': total}


# Tacotron2 with default hyper-parameters except ``n_mels``. It shares N_MELS
# with tacotron_data_gen_fn so the model and the generated spectrogram agree.
model_zoo.register(
    name='torchaudio_tacotron',
    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
    data_gen_fn=tacotron_data_gen_fn,
    output_transform_fn=_tacotron_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
|
|
|
|
|
|
def wav2vec_data_gen_fn():
    """Build a random batch of raw waveforms for the wav2vec2 / HuBERT entries.

    Returns:
        dict with ``waveforms`` of shape ``(4, 400)`` (standard-normal samples)
        and ``lengths`` of shape ``(4,)``. Each entry of ``lengths`` is the
        number of valid frames for one sample and lies in ``[1, 400]``.
    """
    batch_size, num_frames = 4, 400
    waveforms = torch.randn(batch_size, num_frames)
    # ``torch.randint``'s upper bound is exclusive. Drawing from [1, num_frames]
    # rules out empty samples (the other generators in this file also start
    # lengths at 1) while still allowing a fully populated sample.
    lengths = torch.randint(1, num_frames + 1, (batch_size,))
    return dict(waveforms=waveforms, lengths=lengths)
|
|
|
|
|
|
# wav2vec2 and HuBERT share the raw-waveform generator (wav2vec_data_gen_fn)
# and the (frames, lengths) output transform.
#
# Both factories are called with ``encoder_layer_drop=0.0``. Per the torchaudio
# docs, layer drop randomly skips encoder layers in training mode, which makes
# individual forward passes non-reproducible. The wav2vec2 entry already
# disabled it. HuBERT previously kept its (non-zero) factory default and now
# disables it too, so the two entries are configured consistently.
model_zoo.register(
    name='torchaudio_wav2vec2_base',
    model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
    data_gen_fn=wav2vec_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(
    name='torchaudio_hubert_base',
    model_fn=partial(tm.hubert_base, encoder_layer_drop=0.0),
    data_gen_fn=wav2vec_data_gen_fn,
    output_transform_fn=transformer_output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
|