Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion test/torchaudio_unittest/models_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import itertools
from collections import namedtuple

import torch
from torchaudio.models import Wav2Letter, MelResNet, UpsampleNetwork, WaveRNN
from torchaudio.models import (
Wav2Letter,
MelResNet,
UpsampleNetwork,
WaveRNN,
ConvTasNet,
)
from parameterized import parameterized

from torchaudio_unittest import common_utils

Expand Down Expand Up @@ -115,3 +125,58 @@ def test_waveform(self):
out = model(x, mels)

assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)


_ConvTasNetParams = namedtuple(
'_ConvTasNetParams',
[
'enc_num_feats',
'enc_kernel_size',
'msk_num_feats',
'msk_num_hidden_feats',
'msk_kernel_size',
'msk_num_layers',
'msk_num_stacks',
]
)


class TestConvTasNet(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[2, 3],
[
_ConvTasNetParams(128, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(256, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 6, 4),
_ConvTasNetParams(512, 40, 128, 512, 3, 4, 6),
_ConvTasNetParams(512, 40, 128, 512, 3, 8, 3),
_ConvTasNetParams(512, 32, 128, 512, 3, 8, 3),
_ConvTasNetParams(512, 16, 128, 512, 3, 8, 3),
],
)))
def test_paper_configuration(self, num_sources, model_params):
"""ConvTasNet model works on the valid configurations in the paper"""
batch_size = 32
num_frames = 8000

model = ConvTasNet(
num_sources=num_sources,
enc_kernel_size=model_params.enc_kernel_size,
enc_num_feats=model_params.enc_num_feats,
msk_kernel_size=model_params.msk_kernel_size,
msk_num_feats=model_params.msk_num_feats,
msk_num_hidden_feats=model_params.msk_num_hidden_feats,
msk_num_layers=model_params.msk_num_layers,
msk_num_stacks=model_params.msk_num_stacks,
)
tensor = torch.rand(batch_size, 1, num_frames)
output = model(tensor)

assert output.shape == (batch_size, num_sources, num_frames)
1 change: 1 addition & 0 deletions torchaudio/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .wav2letter import *
from .wavernn import *
from .conv_tasnet import ConvTasNet
Loading