mirror of https://github.com/hpcaitech/ColossalAI
[fx] add torchaudio test (#1369)
* [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test and test patches * Delete ~ * [fx] add patches and patches test * [fx] add patches and patches test * [fx] fix patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] merge upstream * [fx] fix import errors
pull/1377/head
parent
fb6f085907
commit
be229217ce
|
@ -108,6 +108,27 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
|||
return torch.empty(final_shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.repeat_interleave)
|
||||
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
    """Meta patch for ``torch.repeat_interleave``: compute only the output shape.

    Args:
        input: tensor whose elements would be repeated; only its shape is read.
        repeats: an ``int`` or a ``torch.Tensor`` of repetition counts. A tensor
            holding a single element is broadcast over the whole axis.
        dim: axis to repeat along. When ``None`` the input is treated as
            flattened and the result is one-dimensional.
        output_size: optional total size of the output along the repeated axis.
            When given it is used directly and ``repeats`` is not read.

    Returns:
        An empty tensor on the ``meta`` device with the resulting shape.

    Raises:
        AssertionError: if ``repeats`` is neither an ``int`` nor a tensor.
    """
    assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
        "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"

    # without an explicit dim the input is flattened before repeating
    shape = list(input.shape) if dim is not None else [input.numel()]
    dim = dim if dim is not None else 0
    dim = input.dim() + dim if dim < 0 else dim

    if output_size is not None:
        # the caller already knows the final size, no need to inspect `repeats`
        shape[dim] = int(output_size)
    elif isinstance(repeats, int):
        shape[dim] = shape[dim] * repeats
    elif repeats.numel() == 1:
        # a single-element tensor is broadcast to every element along `dim`
        shape[dim] = shape[dim] * int(repeats.sum())
    else:
        # one count per element: the new size is the sum of all counts
        shape[dim] = int(repeats.sum())
    return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.Tensor.repeat_interleave)
|
||||
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
    """Meta patch for ``torch.Tensor.repeat_interleave``.

    Forwards to the functional patch, with the tensor itself as the input.
    """
    return torch_repeat_interleave(self, repeats, dim, output_size)
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.roll)
|
||||
def torch_roll(input, shifts, dims=None):
    """Meta patch for ``torch.roll``: return an empty meta tensor shaped like ``input``.

    ``shifts`` and ``dims`` are accepted for signature compatibility but are not read.
    """
    rolled_shape = input.shape
    return torch.empty(rolled_shape, device='meta')
|
||||
|
|
|
@ -3,4 +3,5 @@ from .convolution import *
|
|||
from .embedding import *
|
||||
from .linear import *
|
||||
from .normalization import *
|
||||
from .pooling import *
|
||||
from .pooling import *
|
||||
from .rnn import *
|
|
@ -7,5 +7,6 @@ from ..registry import meta_patched_module
|
|||
@meta_patched_module.register(torch.nn.GELU)
|
||||
@meta_patched_module.register(torch.nn.Tanh)
|
||||
@meta_patched_module.register(torch.nn.ReLU6)
|
||||
@meta_patched_module.register(torch.nn.PReLU)
|
||||
def torch_nn_non_linear_act(self, input):
    """Meta patch shared by the registered activation modules.

    Returns an empty meta tensor with the same shape as ``input``; the module
    instance (``self``) is not consulted.
    """
    activation_shape = input.shape
    return torch.empty(activation_shape, device='meta')
|
||||
|
|
|
@ -55,3 +55,60 @@ def torch_nn_conv3d(self, input):
|
|||
w_out,
|
||||
)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
@meta_patched_module.register(torch.nn.ConvTranspose1d)
|
||||
def torch_nn_convtranspose1d(self, input):
    """Meta patch for ``torch.nn.ConvTranspose1d``: compute only the output shape.

    The output length follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
    """
    stride, padding = self.stride[0], self.padding[0]
    dilation, kernel = self.dilation[0], self.kernel_size[0]
    extra = self.output_padding[0]

    length_in = input.shape[-1]
    length_out = math.floor((length_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + extra + 1)

    # keep every leading dimension, replace the trailing (channels, length) pair
    leading = input.shape[:-2]
    return torch.empty(leading + (self.out_channels, length_out), device='meta')
|
||||
|
||||
@meta_patched_module.register(torch.nn.ConvTranspose2d)
|
||||
def torch_nn_convtranspose2d(self, input):
    """Meta patch for ``torch.nn.ConvTranspose2d``: compute only the output shape.

    Each spatial size follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    """
    per_axis = zip(input.shape[-2:], self.stride, self.padding, self.dilation, self.kernel_size,
                   self.output_padding)
    # (h_out, w_out), one entry per spatial axis
    spatial_out = tuple(
        math.floor((size_in - 1) * stride - 2 * pad + dil * (kernel - 1) + out_pad + 1)
        for size_in, stride, pad, dil, kernel, out_pad in per_axis)

    # keep every leading dimension, replace the trailing (channels, h, w) triple
    result_shape = input.shape[:-3] + (self.out_channels,) + spatial_out
    return torch.empty(result_shape, device='meta')
|
||||
|
||||
@meta_patched_module.register(torch.nn.ConvTranspose3d)
|
||||
def torch_nn_convtranspose3d(self, input):
    """Meta patch for ``torch.nn.ConvTranspose3d``: compute only the output shape.

    Each spatial size follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    """

    def _out_size(axis, size_in):
        # transposed-convolution output size along one spatial axis
        return math.floor((size_in - 1) * self.stride[axis] - 2 * self.padding[axis] +
                          self.dilation[axis] * (self.kernel_size[axis] - 1) + self.output_padding[axis] + 1)

    d_in, h_in, w_in = input.shape[-3:]
    volume_out = (_out_size(0, d_in), _out_size(1, h_in), _out_size(2, w_in))

    # keep every leading dimension, replace the trailing (channels, d, h, w)
    result_shape = input.shape[:-4] + (self.out_channels,) + volume_out
    return torch.empty(result_shape, device='meta')
|
|
@ -6,4 +6,4 @@ from ..registry import meta_patched_module
|
|||
def torch_nn_linear(self, input):
    """Meta patch for ``torch.nn.Linear``: compute only the output shape.

    Asserts that the last input dimension equals ``self.in_features`` and
    returns an empty meta tensor whose last dimension is ``self.out_features``.
    """
    in_dim = input.shape[-1]
    assert in_dim == self.in_features, f'Expected hidden size {self.in_features} but got {in_dim} for the torch.nn.Linear patch'
    out_shape = input.shape[:-1] + (self.out_features,)
    return torch.empty(out_shape, device="meta")
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
from ..registry import meta_patched_module
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.GRU)
|
||||
@meta_patched_module.register(torch.nn.RNN)
|
||||
def torch_nn_rnn(self, input, hx=None):
    """Meta patch for ``torch.nn.RNN`` / ``torch.nn.GRU``: compute only output shapes.

    Args:
        input: input tensor; its last dimension must equal ``self.input_size``.
        hx: optional initial hidden state. When omitted (as the real modules
            allow), a meta hidden state of the default shape is created.

    Returns:
        A tuple ``(output, hidden)``. ``output`` is an empty meta tensor shaped
        like ``input`` with the last dimension replaced by ``hidden_size``
        (doubled when bidirectional). ``hidden`` is ``hx`` itself when it was
        provided, otherwise a new empty meta tensor.

    Raises:
        AssertionError: if the feature size of ``input`` or ``hx`` does not
            match the module.
    """
    assert input.shape[
        -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
    d = 2 if self.bidirectional else 1

    if hx is None:
        # default hidden state: (num_layers * directions, [batch,] hidden_size)
        state_shape = (self.num_layers * d,)
        if input.dim() == 3:
            # batched input: the batch axis position depends on batch_first
            batch_size = input.shape[0] if self.batch_first else input.shape[1]
            state_shape = state_shape + (batch_size,)
        hx = torch.empty(state_shape + (self.hidden_size,), device="meta")
    else:
        assert hx.shape[
            -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'

    return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
|
|
@ -27,7 +27,7 @@ def save_checkpoint(dire: str,
|
|||
# save the dist context about the tensors in a new dict, while still maintain the original dict.
|
||||
for k, v in model_state.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
gather_tensor(v) # gather shared tensors to rank0
|
||||
gather_tensor(v) # gather shared tensors to rank0
|
||||
# don't recover tensors in rank0, since the dict is only a copy of model
|
||||
|
||||
if rank == 0:
|
||||
|
|
|
@ -34,7 +34,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
|
|||
dist.barrier()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
setattr(colo_tensor, 'save_ready', True) # set saving signitrue
|
||||
setattr(colo_tensor, 'save_ready', True) # set saving signature
|
||||
|
||||
|
||||
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
|
||||
|
@ -54,9 +54,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
|
|||
if dist.get_rank() == 0:
|
||||
colo_tensor.set_dist_spec(dist_spec)
|
||||
else:
|
||||
rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
|
||||
pg=colo_tensor.get_process_group(),
|
||||
compute_attr=colo_tensor.compute_spec))
|
||||
rep_tensor = ColoTensor(
|
||||
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
|
||||
rep_tensor.set_dist_spec(dist_spec)
|
||||
with torch.no_grad():
|
||||
colo_tensor.data.copy_(rep_tensor.data)
|
||||
|
|
|
@ -3,4 +3,5 @@ torchvision
|
|||
transformers
|
||||
timm
|
||||
titans
|
||||
torchaudio
|
||||
torchrec
|
||||
|
|
4
setup.py
4
setup.py
|
@ -100,7 +100,7 @@ def get_version():
|
|||
version += f'+torch{torch_version}cu{cuda_version}'
|
||||
return version
|
||||
|
||||
|
||||
|
||||
if build_cuda_ext:
|
||||
try:
|
||||
import torch
|
||||
|
@ -115,7 +115,7 @@ if build_cuda_ext:
|
|||
except ImportError:
|
||||
print('torch is not found. CUDA extension will not be installed')
|
||||
build_cuda_ext = False
|
||||
|
||||
|
||||
if build_cuda_ext:
|
||||
build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
|
||||
|
||||
|
|
|
@ -4,7 +4,12 @@ from colossalai.fx.tracer.meta_patch import patched_module
|
|||
|
||||
def _run(data, module, patch_fn):
|
||||
try:
|
||||
output = patch_fn(module, data)
|
||||
if isinstance(data, dict):
|
||||
output = patch_fn(module, **data)
|
||||
if isinstance(data, tuple) or isinstance(data, list):
|
||||
output = patch_fn(module, *data)
|
||||
else:
|
||||
output = patch_fn(module, data)
|
||||
return output
|
||||
except Exception as e:
|
||||
return e
|
||||
|
@ -17,8 +22,13 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape)
|
|||
assert isinstance(output, AssertionError)
|
||||
else:
|
||||
assert not isinstance(output, Exception)
|
||||
assert output.is_meta
|
||||
assert output.shape == output_shape
|
||||
if isinstance(output, tuple):
|
||||
for item, shape in zip(output, output_shape):
|
||||
assert item.is_meta
|
||||
assert item.shape == shape
|
||||
else:
|
||||
assert output.is_meta
|
||||
assert output.shape == output_shape
|
||||
|
||||
|
||||
def test_linear():
|
||||
|
@ -27,11 +37,27 @@ def test_linear():
|
|||
module = torch.nn.Linear(4, 2)
|
||||
_assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2]))
|
||||
|
||||
# Test if the linear patch can catch exception when dimension does not match
|
||||
# test if the linear patch can catch exception when dimension does not match
|
||||
data = torch.rand(2, 2, device='meta')
|
||||
_assert_output_shape(data, module, patched_module.torch_nn_linear, True, None)
|
||||
|
||||
|
||||
def test_rnn():
    """Check the RNN patch against a real ``torch.nn.RNN`` forward pass."""
    module = torch.nn.RNN(10, 20, 2)

    # reference shapes come from a materialized forward pass
    data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20))
    output, hn = module(*data)

    # test rnn patch can produce the meta output with correct shape
    meta_data = tuple(tensor.to('meta') for tensor in data)
    _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape))

    # test if the rnn patch can catch exception when dimension does not match
    meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta'))
    _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None)
|
||||
|
||||
|
||||
def test_embedding():
|
||||
data = torch.rand(2, 4, device='meta')
|
||||
|
||||
|
@ -146,7 +172,7 @@ def test_conv1d():
|
|||
|
||||
|
||||
def test_conv2d():
|
||||
# test conv 1d
|
||||
# test conv 2d
|
||||
data = torch.rand(2, 3, 4, 4)
|
||||
conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2)
|
||||
materialized_output = conv2d(data)
|
||||
|
@ -187,7 +213,7 @@ def test_conv2d():
|
|||
|
||||
|
||||
def test_conv3d():
|
||||
# test conv 1d
|
||||
# test conv 3d
|
||||
data = torch.rand(2, 3, 4, 4, 4)
|
||||
conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2)
|
||||
materialized_output = conv3d(data)
|
||||
|
@ -227,6 +253,75 @@ def test_conv3d():
|
|||
output_shape=materialized_output.shape)
|
||||
|
||||
|
||||
def test_conv_transpose1d():
    """Compare the ConvTranspose1d patch with the shape of a real forward pass."""
    data = torch.rand(2, 3, 4)
    meta_data = data.to('meta')

    # first the default configuration, then the same layer with padding
    for extra_kwargs in ({}, {'padding': 1}):
        convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        materialized_output = convtrans1d(data)
        _assert_output_shape(data=meta_data,
                             module=convtrans1d,
                             patch_fn=patched_module.torch_nn_convtranspose1d,
                             expect_exception=False,
                             output_shape=materialized_output.shape)
|
||||
|
||||
|
||||
def test_conv_transpose2d():
    """Compare the ConvTranspose2d patch with the shape of a real forward pass."""
    data = torch.rand(2, 3, 4, 4)
    meta_data = data.to('meta')

    # first the default configuration, then the same layer with padding
    for extra_kwargs in ({}, {'padding': 1}):
        convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        materialized_output = convtrans2d(data)
        _assert_output_shape(data=meta_data,
                             module=convtrans2d,
                             patch_fn=patched_module.torch_nn_convtranspose2d,
                             expect_exception=False,
                             output_shape=materialized_output.shape)
|
||||
|
||||
|
||||
def test_conv_transpose3d():
    """Compare the ConvTranspose3d patch with the shape of a real forward pass."""
    # test conv transpose3d
    data = torch.rand(2, 3, 4, 4, 4)
    meta_data = data.to('meta')

    # first the default configuration, then the same layer with padding
    for extra_kwargs in ({}, {'padding': 1}):
        convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        materialized_output = convtrans3d(data)
        _assert_output_shape(data=meta_data,
                             module=convtrans3d,
                             patch_fn=patched_module.torch_nn_convtranspose3d,
                             expect_exception=False,
                             output_shape=materialized_output.shape)
|
||||
|
||||
|
||||
def test_pool1d():
|
||||
combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
|
||||
[torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import torch
|
||||
from colossalai.fx.tracer.meta_patch import patched_function
|
||||
from functools import partial
|
||||
|
||||
|
||||
def _run(data, patch_fn):
|
||||
try:
|
||||
output = patch_fn(data)
|
||||
return output
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
|
||||
def _assert_output_shape(data, patch_fn, expect_exception, output_shape):
    """Run ``patch_fn`` on ``data`` and check the outcome.

    When ``expect_exception`` is true the patch must have raised an
    ``AssertionError``; otherwise it must return a meta tensor whose shape
    equals ``output_shape``.
    """
    output = _run(data, patch_fn)

    if expect_exception:
        assert isinstance(output, AssertionError)
        return

    assert not isinstance(output, Exception)
    assert output.is_meta
    assert output.shape == output_shape
|
||||
|
||||
|
||||
def test_repeat_interleave():
    """Compare the repeat_interleave patch with real ``torch.repeat_interleave`` shapes."""
    patch_fn = patched_function.torch_repeat_interleave
    matrix = [[1, 2], [3, 4]]

    # examples from https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
    # each case: (input values, kwargs for torch, kwargs for the patch, expect_exception)
    cases = [
        ([1, 2, 3], dict(repeats=2), dict(repeats=2), False),
        (matrix, dict(repeats=3, dim=1), dict(repeats=3, dim=1), False),
        (matrix, dict(repeats=torch.tensor([1, 2]), dim=-1), dict(repeats=torch.tensor([1, 2]), dim=-1), False),
        # the patch is given a plain list for `repeats` and is expected to raise
        (matrix, dict(repeats=torch.tensor([1, 2]), dim=0), dict(repeats=[1, 2], dim=0), True),
    ]

    for values, torch_kwargs, patch_kwargs, expect_exception in cases:
        data = torch.tensor(values)
        materialized_output = torch.repeat_interleave(data, **torch_kwargs)
        repeat_interleave = partial(patch_fn, **patch_kwargs)
        meta_data = data.to('meta')
        _assert_output_shape(data=meta_data,
                             patch_fn=repeat_interleave,
                             expect_exception=expect_exception,
                             output_shape=materialized_output.shape)
|
|
@ -22,7 +22,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
|
|||
with torch.no_grad():
|
||||
fx_out = gm(data)
|
||||
non_fx_out = model(data)
|
||||
|
||||
|
||||
# compare output
|
||||
if isinstance(fx_out, tuple):
|
||||
# some models produce tuple as output
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
import torch
|
||||
from torchaudio_utils import trace_and_compare
|
||||
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
|
||||
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
|
||||
import pytest
|
||||
|
||||
|
||||
def test_wave2letter_waveform():
    """Trace ``Wav2Letter`` with its default input type and compare with eager execution."""
    num_features, num_classes = 1, 40
    input_shape = (2, num_features, 320)    # (batch_size, num_features, input_length)

    model = Wav2Letter(num_classes=num_classes, num_features=num_features)

    def data_gen():
        return dict(x=torch.rand(*input_shape))

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
|
||||
|
||||
|
||||
def test_wave2letter_mfcc():
    """Trace ``Wav2Letter`` with ``input_type="mfcc"`` and compare with eager execution."""
    num_features, num_classes = 13, 40
    input_shape = (2, num_features, 2)    # (batch_size, num_features, input_length)

    model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

    def data_gen():
        return dict(x=torch.rand(*input_shape))

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
|
||||
|
||||
|
||||
def test_melresnet_waveform():
    """Trace ``MelResNet`` and compare the traced module with eager execution."""
    n_batch, n_freq, n_time = 2, 100, 200
    n_res_block, n_hidden, n_output, kernel_size = 10, 128, 128, 5

    model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

    def data_gen():
        specgram = torch.rand(n_batch, n_freq, n_time)
        return dict(specgram=specgram)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
|
||||
|
||||
|
||||
def test_upsample_network_waveform():
    """Trace ``UpsampleNetwork`` and compare the traced module with eager execution."""
    upsample_scales = [5, 5, 8]
    n_batch = 2
    n_time = 200
    n_freq = 100
    n_output = 64
    n_res_block = 10
    n_hidden = 32
    kernel_size = 5

    # NOTE: the original also accumulated the product of `upsample_scales`
    # into an unused local; that dead computation has been removed.
    model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)

    def data_gen():
        x = torch.rand(n_batch, n_freq, n_time)
        return dict(specgram=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
|
||||
|
||||
|
||||
def test_wavernn_waveform():
    """Trace ``WaveRNN`` using meta arguments and compare with eager execution."""
    upsample_scales = [2, 2, 5]
    n_rnn = n_fc = n_output = n_hidden = 16
    n_classes, n_freq = 10, 10
    hop_length, n_time = 20, 20
    n_batch, n_res_block, kernel_size = 2, 3, 5

    model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden,
                    n_output)

    def data_gen():
        waveform_length = hop_length * (n_time - kernel_size + 1)
        waveform = torch.rand(n_batch, 1, waveform_length)
        specgram = torch.rand(n_batch, 1, n_freq, n_time)
        return dict(waveform=waveform, specgram=specgram)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||
|
||||
|
||||
def test_convtasnet_config():
    """Trace a default-configured ``ConvTasNet`` using meta arguments and compare with eager execution."""
    input_shape = (32, 1, 800)    # (batch_size, 1, num_frames)

    model = ConvTasNet()

    def data_gen():
        return dict(input=torch.rand(*input_shape))

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||
|
||||
|
||||
def test_deepspeech():
    """Trace ``DeepSpeech`` and compare the traced module with eager execution."""
    n_feature, n_class = 1, 40
    input_shape = (2, 1, 32, n_feature)    # (n_batch, n_channel, n_time, n_feature)

    model = DeepSpeech(n_feature=n_feature, n_class=n_class)

    def data_gen():
        return dict(x=torch.rand(*input_shape))

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # run every test in this file, in order, when executed as a script
    for test_fn in (
            test_wave2letter_waveform,
            test_wave2letter_mfcc,
            test_melresnet_waveform,
            test_upsample_network_waveform,
            test_wavernn_waveform,
            test_convtasnet_config,
            test_deepspeech,
    ):
        test_fn()
|
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
from torchaudio.models import Tacotron2
|
||||
from torchaudio_utils import trace_and_compare
|
||||
import pytest
|
||||
|
||||
|
||||
def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
    """Build the ``Tacotron2`` model used by the tracing test.

    Only ``n_mels``, ``decoder_max_step`` and ``gate_threshold`` are
    configurable; every other hyper-parameter is fixed below.
    """
    config = dict(
        mask_padding=False,
        n_mels=n_mels,
        n_symbol=20,
        n_frames_per_step=1,
        symbol_embedding_dim=32,
        encoder_embedding_dim=32,
        encoder_n_convolution=3,
        encoder_kernel_size=5,
        decoder_rnn_dim=32,
        decoder_max_step=decoder_max_step,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=32,
        attention_hidden_dim=32,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=32,
        postnet_n_convolution=5,
        postnet_kernel_size=5,
        postnet_embedding_dim=512,
        gate_threshold=gate_threshold,
    )
    return Tacotron2(**config)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_tacotron_model():
    """Trace ``Tacotron2`` using meta arguments and compare with eager execution."""
    n_mels = 80
    n_batch = 3
    max_mel_specgram_length = 300
    max_text_length = 100
    # must match the `n_symbol=20` passed by `_get_tacotron2_model`; token ids
    # at or above it would be out of range for the symbol embedding
    n_symbol = 20

    model = _get_tacotron2_model(n_mels)

    def data_gen():
        text = torch.randint(0, n_symbol, (n_batch, max_text_length))
        text_lengths = max_text_length * torch.ones((n_batch,))
        mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length)
        mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
        return dict(tokens=text,
                    token_lengths=text_lengths,
                    mel_specgram=mel_specgram,
                    mel_specgram_lengths=mel_specgram_lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # allow running this test file directly as a script
    test_tacotron_model()
|
|
@ -0,0 +1,61 @@
|
|||
import torch
|
||||
from torchaudio_utils import trace_and_compare
|
||||
from torchaudio.models import Emformer, Conformer
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_conformer():
    """Trace ``Conformer`` using concrete arguments and compare with eager execution."""
    input_dim, batch_size, num_frames = 80, 10, 400

    model = Conformer(
        input_dim=input_dim,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        depthwise_conv_kernel_size=31,
    )

    def data_gen():
        lengths = torch.randint(1, num_frames, (batch_size,))
        # the time axis is sized to the longest sequence drawn above
        longest = int(lengths.max())
        features = torch.rand(batch_size, longest, input_dim)
        return dict(input=features, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=True)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_emformer():
    """Trace ``Emformer`` using meta arguments and compare with eager execution."""
    input_dim, num_heads, ffn_dim, num_layers = 128, 8, 256, 3
    segment_length, right_context_length = 4, 1
    batch_size, num_frames = 10, 400

    model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length)

    def data_gen():
        lengths = torch.randint(1, num_frames, (batch_size,))
        features = torch.rand(batch_size, num_frames, input_dim)
        return dict(input=features, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_torchaudio_transformers():
    """Run the Conformer test and then the Emformer test."""
    for transformer_test in (test_conformer, test_emformer):
        transformer_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # allow running this test file directly as a script
    test_torchaudio_transformers()
|
|
@ -0,0 +1,50 @@
|
|||
import torch
|
||||
from torchaudio.models.wav2vec2 import (
|
||||
hubert_base,
|
||||
hubert_large,
|
||||
hubert_xlarge,
|
||||
wav2vec2_base,
|
||||
wav2vec2_large,
|
||||
wav2vec2_large_lv60k,
|
||||
)
|
||||
from torchaudio_utils import trace_and_compare
|
||||
import pytest
|
||||
|
||||
# model factory functions exercised by `test_wav2vec`
MODEL_LIST = [
    hubert_base, hubert_large, hubert_xlarge,
    wav2vec2_base, wav2vec2_large, wav2vec2_large_lv60k,
]
|
||||
|
||||
|
||||
def _smoke_test(model, device):
    """Move ``model`` to ``device``, then trace it with meta arguments and compare with eager execution."""
    model = model.to(device=device)
    batch_size, num_frames = 3, 1024

    def data_gen():
        waveforms = torch.randn(batch_size, num_frames, device=device)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size], device=device)
        return dict(waveforms=waveforms, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_wav2vec():
    """Smoke-test every model factory in ``MODEL_LIST`` on the CPU."""
    for build_model in MODEL_LIST:
        model = build_model()
        _smoke_test(model, 'cpu')
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # allow running this test file directly as a script
    test_wav2vec()
|
|
@ -0,0 +1,28 @@
|
|||
from colossalai.fx import ColoTracer
|
||||
import torch
|
||||
from torch.fx import GraphModule, Tracer
|
||||
|
||||
|
||||
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False):
    """Trace ``model`` with ``ColoTracer`` and check the traced module matches eager execution.

    Args:
        model: the module to trace.
        data_gen: zero-argument callable returning a dict of keyword arguments
            for ``model``.
        need_meta: when true, meta-device copies of the generated tensors are
            passed to the tracer as ``meta_args``.
        need_concrete: when true, the generated data is passed to the tracer as
            ``concrete_args``.

    Raises:
        AssertionError: if the traced and eager outputs differ in number or are
            not element-wise close.
    """
    data = data_gen()
    concrete_args = data if need_concrete else {}
    meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
    tracer = ColoTracer()

    graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # evaluation mode on both sides before comparing outputs
    model.eval()
    gm.eval()

    with torch.no_grad():
        non_fx_out = model(**data)
        fx_out = gm(**data)

    if isinstance(fx_out, tuple):
        # zip() stops at the shorter sequence, so compare the lengths explicitly
        assert len(fx_out) == len(non_fx_out), \
            f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
        for non_fx, fx in zip(non_fx_out, fx_out):
            assert torch.allclose(non_fx,
                                  fx), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
    else:
        assert torch.allclose(
            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
Loading…
Reference in New Issue