From 31d403999025d63d2b15890e9fec754a4a2a3d14 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 1 Sep 2020 17:05:57 +0000 Subject: [PATCH 1/7] Add Conv-TasNet model --- .../source_separation/conv_tasnet/model.py | 325 ++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 examples/source_separation/conv_tasnet/model.py diff --git a/examples/source_separation/conv_tasnet/model.py b/examples/source_separation/conv_tasnet/model.py new file mode 100644 index 0000000000..a2a4824214 --- /dev/null +++ b/examples/source_separation/conv_tasnet/model.py @@ -0,0 +1,325 @@ +"""Implements Conv-TasNet with building blocks of it.""" + +from typing import Tuple, Optional + +import torch + + +class ConvBlock(torch.nn.Module): + """1D Convolutional block. + + Args: + channels (int): The number of input/output channels, + hidden_channels (int): The number of channels in the internal layers, . + kernel_size (int): The convolution kernel size of the middle layer,

. + padding (int): Padding value of the convolution in the middle layer. + dilation (int): Dilation value of the convolution in the middle layer. + causal (bool): Switch causal/non-causal implementation. + no_redisual (bool): Disable residual block/output. + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + + def __init__( + self, + io_channels: int, + hidden_channels: int, + kernel_size: int, + padding: int, + dilation: int = 1, + causal: bool = False, + no_residual: bool = False, + ): + super().__init__() + + if causal: + raise NotImplementedError("causal=True is not implemented") + + self.conv_layers = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels=io_channels, out_channels=hidden_channels, kernel_size=1 + ), + torch.nn.PReLU(), + torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), + torch.nn.Conv1d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + groups=hidden_channels, + ), + torch.nn.PReLU(), + torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), + ) + + self.res_out = ( + None + if no_residual + else torch.nn.Conv1d( + in_channels=hidden_channels, out_channels=io_channels, kernel_size=1 + ) + ) + self.skip_out = torch.nn.Conv1d( + in_channels=hidden_channels, out_channels=io_channels, kernel_size=1 + ) + + def forward( + self, input: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + feature = self.conv_layers(input) + if self.res_out is None: + residual = None + else: + residual = self.res_out(feature) + skip_out = self.skip_out(feature) + return residual, skip_out + + +class MaskGenerator(torch.nn.Module): + """TCN (Temporal Convolution Network) Separation Module + + Generates masks for separation. + + Args: + input_dim (int): Input feature dimension, . + num_sources (int): The number of sources to separate. + kernel_size (int): The convolution kernel size of conv blocks,

. + num_featrs (int): Input/output feature dimenstion of conv blocks, . + num_hidden (int): Intermediate feature dimention of conv blocks, + num_layers (int): The number of conv blocks in one stack, . + num_stacks (int): The number of conv block stacks, . + causal (bool): Switch causal/non-causal implementation. + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + + def __init__( + self, + input_dim: int, + num_sources: int, + kernel_size: int, + num_feats: int, + num_hidden: int, + num_layers: int, + num_stacks: int, + causal: bool = False, + ): + if causal: + raise NotImplementedError("causal=True is not implemented") + + super().__init__() + + self.input_dim = input_dim + self.num_sources = num_sources + + self.norm_layers = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8), + torch.nn.Conv1d( + in_channels=input_dim, out_channels=num_feats, kernel_size=1 + ), + ) + self.receptive_field = 0 + self.conv_layers = torch.nn.ModuleList([]) + for s in range(num_stacks): + for l in range(num_layers): + multi = 2 ** l + self.conv_layers.append( + ConvBlock( + io_channels=num_feats, + hidden_channels=num_hidden, + kernel_size=kernel_size, + dilation=multi, + padding=multi, + causal=causal, + # The last ConvBlock does not need residual + no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)), + ) + ) + self.receptive_field += ( + kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi + ) + self.output_layer = torch.nn.Sequential( + torch.nn.PReLU(), + torch.nn.Conv1d( + in_channels=num_feats, + out_channels=input_dim * num_sources, + kernel_size=1, + ), + torch.nn.Sigmoid(), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Generate separation mask. + + Args: + input (torch.Tensor): 3D Tensor with shape [batch, features, frames] + + Returns: + torch.Tensor: shape [batch, num_sources, features, frames] + """ + batch_size = input.shape[0] + feats = self.norm_layers(input) + output = 0.0 + for layer in self.conv_layers: + residual, skip = layer(feats) + if residual is not None: # the last conv layer does not produce residual + feats = feats + residual + output = output + skip + output = self.output_layer(output) + return output.view(batch_size, self.num_sources, self.input_dim, -1) + + +class ConvTasNet(torch.nn.Module): + """Conv-TasNet: a fully-convolutional time-domain audio separation network + + Args: + num_sources (int): The number of sources to split. + enc_kernel_size (int): The convolution kernel size of the encoder/decoder, . + enc_num_feats (int): The feature dimensions passed to mask generator, . + msk_kernel_size (int): The convolution kernel size of the mask generator,

. + msk_num_feats (int): The input/output feature dimension of conv block in the mask generator, . + msk_num_hidden_feats (int): The internal feature dimension of conv block of the mask generator, . + msk_num_layers (int): The number of layers in one conv block of the mask generator, . + msk_num_stacks (int): The numbr of conv blocks of the mask generator, . + causal (bool): Switch causal/non-causal implementation. + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + + def __init__( + self, + num_sources: int = 2, + # encoder/decoder parameters + enc_kernel_size: int = 16, + enc_num_feats: int = 512, + # mask generator parameters + msk_kernel_size: int = 3, + msk_num_feats: int = 128, + msk_num_hidden_feats: int = 512, + msk_num_layers: int = 8, + msk_num_stacks: int = 3, + causal: bool = False, + ): + super().__init__() + + if causal: + raise NotImplementedError("causal=True is not implemented") + + self.num_sources = num_sources + self.enc_num_feats = enc_num_feats + self.enc_kernel_size = enc_kernel_size + self.enc_stride = enc_kernel_size // 2 + + self.encoder = torch.nn.Conv1d( + in_channels=1, + out_channels=enc_num_feats, + kernel_size=enc_kernel_size, + stride=self.enc_stride, + padding=self.enc_stride, + bias=False, + ) + self.mask_generator = MaskGenerator( + input_dim=enc_num_feats, + num_sources=num_sources, + kernel_size=msk_kernel_size, + num_feats=msk_num_feats, + num_hidden=msk_num_hidden_feats, + num_layers=msk_num_layers, + num_stacks=msk_num_stacks, + ) + self.decoder = torch.nn.ConvTranspose1d( + in_channels=enc_num_feats, + out_channels=1, + kernel_size=enc_kernel_size, + stride=self.enc_stride, + padding=self.enc_stride, + bias=False, + ) + + def _pad_input(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]: + """Pad input Tensor so that the end of the input tensor corresponds with + + 1. (if kernel size is odd) the center of the last convolution kernel + or 2. (if kernel size is even) the end of the first half of the last convolution kernel + + Assuming that the resulting Tensor will be zero-padded with the size of stride + on the both ends in Conv1D + + |<--- k_1 --->| + | | |<-- k_n-1 -->| + | | | |<--- k_n --->| + | | | | | + | | | | | + | v v v | + |<---->|<--- input signal --->|<--->|<---->| + stride PAD stride + + Args: + input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames) + + Returns: + torch.Tensor: Padded Tensor + int: Number of paddings performed + """ + batch_size, num_channels, num_frames = input.shape + is_odd = self.enc_kernel_size % 2 + num_strides = (num_frames - is_odd) // self.enc_stride + num_remainings = num_frames - (is_odd + num_strides * self.enc_stride) + if num_remainings == 0: + return input, 0 + + num_paddings = self.enc_stride - num_remainings + pad = torch.zeros( + batch_size, + num_channels, + num_paddings, + dtype=input.dtype, + device=input.device, + ) + return torch.cat([input, pad], 2), num_paddings + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Perform source separation. Generate audio source waveforms. + + Args: + input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames] + + Returns: + torch.Tensor: 3D Tensor with shape [batch, channel==num_sources, frames] + """ + if input.ndim != 3 or input.shape[1] != 1: + raise ValueError( + f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}" + ) + + # B: batch size + # L: input frame length + # L': padded input frame length + # F: feature dimension + # M: feature frame length + # S: number of sources + + padded, num_pads = self._pad_input(input) # B, 1, L' + batch_size, num_padded_frames = padded.shape[0], padded.shape[2] + feats = self.encoder(padded) # B, F, M + masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M + masked = masked.view( + batch_size * self.num_sources, self.enc_num_feats, -1 + ) # B*S, F, M + decoded = self.decoder(masked) # B*S, 1, L' + output = decoded.view( + batch_size, self.num_sources, num_padded_frames + ) # B, S, L' + if num_pads > 0: + output = output[..., :-num_pads] # B, S, L + return output From 73ed48aa2136721b968fc643f69dc8cfa4751012 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 22 Sep 2020 16:44:25 -0700 Subject: [PATCH 2/7] Remove causal parameter and add note --- .../source_separation/conv_tasnet/model.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/examples/source_separation/conv_tasnet/model.py b/examples/source_separation/conv_tasnet/model.py index a2a4824214..3d472a389a 100644 --- a/examples/source_separation/conv_tasnet/model.py +++ b/examples/source_separation/conv_tasnet/model.py @@ -14,9 +14,11 @@ class ConvBlock(torch.nn.Module): kernel_size (int): The convolution kernel size of the middle layer,

. padding (int): Padding value of the convolution in the middle layer. dilation (int): Dilation value of the convolution in the middle layer. - causal (bool): Switch causal/non-causal implementation. no_redisual (bool): Disable residual block/output. + Note: + This implementation corresponds to the "causal" setting in the paper. + References: - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation Luo, Yi and Mesgarani, Nima @@ -30,14 +32,10 @@ def __init__( kernel_size: int, padding: int, dilation: int = 1, - causal: bool = False, no_residual: bool = False, ): super().__init__() - if causal: - raise NotImplementedError("causal=True is not implemented") - self.conv_layers = torch.nn.Sequential( torch.nn.Conv1d( in_channels=io_channels, out_channels=hidden_channels, kernel_size=1 @@ -92,7 +90,9 @@ class MaskGenerator(torch.nn.Module): num_hidden (int): Intermediate feature dimention of conv blocks, num_layers (int): The number of conv blocks in one stack, . num_stacks (int): The number of conv block stacks, . - causal (bool): Switch causal/non-causal implementation. + + Note: + This implementation corresponds to the "causal" setting in the paper. References: - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation @@ -109,11 +109,7 @@ def __init__( num_hidden: int, num_layers: int, num_stacks: int, - causal: bool = False, ): - if causal: - raise NotImplementedError("causal=True is not implemented") - super().__init__() self.input_dim = input_dim @@ -137,7 +133,6 @@ def __init__( kernel_size=kernel_size, dilation=multi, padding=multi, - causal=causal, # The last ConvBlock does not need residual no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)), ) @@ -188,7 +183,9 @@ class ConvTasNet(torch.nn.Module): msk_num_hidden_feats (int): The internal feature dimension of conv block of the mask generator, . msk_num_layers (int): The number of layers in one conv block of the mask generator, . msk_num_stacks (int): The numbr of conv blocks of the mask generator, . - causal (bool): Switch causal/non-causal implementation. + + Note: + This implementation corresponds to the "causal" setting in the paper. References: - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation @@ -208,13 +205,9 @@ def __init__( msk_num_hidden_feats: int = 512, msk_num_layers: int = 8, msk_num_stacks: int = 3, - causal: bool = False, ): super().__init__() - if causal: - raise NotImplementedError("causal=True is not implemented") - self.num_sources = num_sources self.enc_num_feats = enc_num_feats self.enc_kernel_size = enc_kernel_size From 26d934f67a6b33060f33ef5af2c9b5497a43026c Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 22 Sep 2020 17:04:40 -0700 Subject: [PATCH 3/7] Fix not on causal --- examples/source_separation/conv_tasnet/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/source_separation/conv_tasnet/model.py b/examples/source_separation/conv_tasnet/model.py index 3d472a389a..dea28f7b2b 100644 --- a/examples/source_separation/conv_tasnet/model.py +++ b/examples/source_separation/conv_tasnet/model.py @@ -17,7 +17,7 @@ class ConvBlock(torch.nn.Module): no_redisual (bool): Disable residual block/output. Note: - This implementation corresponds to the "causal" setting in the paper. + This implementation corresponds to the "non-causal" setting in the paper. References: - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation @@ -92,7 +92,7 @@ class MaskGenerator(torch.nn.Module): num_stacks (int): The number of conv block stacks, . Note: - This implementation corresponds to the "causal" setting in the paper. + This implementation corresponds to the "non-causal" setting in the paper. References: - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation @@ -185,7 +185,7 @@ class ConvTasNet(torch.nn.Module): msk_num_stacks (int): The numbr of conv blocks of the mask generator, . Note: - This implementation corresponds to the "causal" setting in the paper. + This implementation corresponds to the "non-causal" setting in the paper. References: - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation From a7745f8fe3418328ab581ad0d11cb7486104505d Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 23 Sep 2020 00:11:00 +0000 Subject: [PATCH 4/7] Add link to the reference --- examples/source_separation/conv_tasnet/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/source_separation/conv_tasnet/model.py b/examples/source_separation/conv_tasnet/model.py index dea28f7b2b..babaef7820 100644 --- a/examples/source_separation/conv_tasnet/model.py +++ b/examples/source_separation/conv_tasnet/model.py @@ -1,4 +1,7 @@ -"""Implements Conv-TasNet with building blocks of it.""" +"""Implements Conv-TasNet with building blocks of it. + +Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c +""" from typing import Tuple, Optional From 10bb60fd8d1a62a3fee8f8d40456bb7f0c25f1b6 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 23 Sep 2020 17:13:35 +0000 Subject: [PATCH 5/7] dissovle input/output sequential --- .../source_separation/conv_tasnet/model.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/examples/source_separation/conv_tasnet/model.py b/examples/source_separation/conv_tasnet/model.py index babaef7820..2c2d5029c9 100644 --- a/examples/source_separation/conv_tasnet/model.py +++ b/examples/source_separation/conv_tasnet/model.py @@ -118,12 +118,11 @@ def __init__( self.input_dim = input_dim self.num_sources = num_sources - self.norm_layers = torch.nn.Sequential( - torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8), - torch.nn.Conv1d( - in_channels=input_dim, out_channels=num_feats, kernel_size=1 - ), - ) + self.input_norm = torch.nn.GroupNorm( + num_groups=1, num_channels=input_dim, eps=1e-8) + self.input_conv = torch.nn.Conv1d( + in_channels=input_dim, out_channels=num_feats, kernel_size=1) + self.receptive_field = 0 self.conv_layers = torch.nn.ModuleList([]) for s in range(num_stacks): @@ -143,15 +142,12 @@ def __init__( self.receptive_field += ( kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi ) - self.output_layer = torch.nn.Sequential( - torch.nn.PReLU(), - torch.nn.Conv1d( + self.output_prelu = torch.nn.PReLU() + self.output_conv = torch.nn.Conv1d( in_channels=num_feats, out_channels=input_dim * num_sources, kernel_size=1, - ), - torch.nn.Sigmoid(), - ) + ) def forward(self, input: torch.Tensor) -> torch.Tensor: """Generate separation mask. @@ -163,14 +159,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: shape [batch, num_sources, features, frames] """ batch_size = input.shape[0] - feats = self.norm_layers(input) + feats = self.input_norm(input) + feats = self.input_conv(feats) output = 0.0 for layer in self.conv_layers: residual, skip = layer(feats) if residual is not None: # the last conv layer does not produce residual feats = feats + residual output = output + skip - output = self.output_layer(output) + output = self.output_prelu(output) + output = self.output_conv(output) + output = torch.sigmoid(output) return output.view(batch_size, self.num_sources, self.input_dim, -1) From 3592081ea53987f0000a833dd2d4cd548d23bb61 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 23 Sep 2020 17:24:58 +0000 Subject: [PATCH 6/7] Update model name --- examples/source_separation/conv_tasnet/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/source_separation/conv_tasnet/model.py b/examples/source_separation/conv_tasnet/model.py index 2c2d5029c9..f7ef5da187 100644 --- a/examples/source_separation/conv_tasnet/model.py +++ b/examples/source_separation/conv_tasnet/model.py @@ -241,14 +241,15 @@ def __init__( bias=False, ) - def _pad_input(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]: + def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]: """Pad input Tensor so that the end of the input tensor corresponds with 1. (if kernel size is odd) the center of the last convolution kernel or 2. (if kernel size is even) the end of the first half of the last convolution kernel - Assuming that the resulting Tensor will be zero-padded with the size of stride - on the both ends in Conv1D + Assumption: + The resulting Tensor will be padded with the size of stride (== kernel_width // 2) + on the both ends in Conv1D |<--- k_1 --->| | | |<-- k_n-1 -->| @@ -304,7 +305,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # M: feature frame length # S: number of sources - padded, num_pads = self._pad_input(input) # B, 1, L' + padded, num_pads = self._align_num_frames_with_strides(input) # B, 1, L' batch_size, num_padded_frames = padded.shape[0], padded.shape[2] feats = self.encoder(padded) # B, F, M masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M From 76d5d61c5ca35ab07ad1bbfa2c9f824bf6d7470b Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 25 Sep 2020 15:04:32 +0000 Subject: [PATCH 7/7] Move model to library and add test --- test/torchaudio_unittest/models_test.py | 67 ++++++++++++++++++- torchaudio/models/__init__.py | 1 + .../models/conv_tasnet.py | 0 3 files changed, 67 insertions(+), 1 deletion(-) rename examples/source_separation/conv_tasnet/model.py => torchaudio/models/conv_tasnet.py (100%) diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index f9e96edd27..c389f2cc06 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -1,5 +1,15 @@ +import itertools +from collections import namedtuple + import torch -from torchaudio.models import Wav2Letter, MelResNet, UpsampleNetwork, WaveRNN +from torchaudio.models import ( + Wav2Letter, + MelResNet, + UpsampleNetwork, + WaveRNN, + ConvTasNet, +) +from parameterized import parameterized from torchaudio_unittest import common_utils @@ -115,3 +125,58 @@ def test_waveform(self): out = model(x, mels) assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes) + + +_ConvTasNetParams = namedtuple( + '_ConvTasNetParams', + [ + 'enc_num_feats', + 'enc_kernel_size', + 'msk_num_feats', + 'msk_num_hidden_feats', + 'msk_kernel_size', + 'msk_num_layers', + 'msk_num_stacks', + ] +) + + +class TestConvTasNet(common_utils.TorchaudioTestCase): + @parameterized.expand(list(itertools.product( + [2, 3], + [ + _ConvTasNetParams(128, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(256, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 256, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 512, 3, 6, 4), + _ConvTasNetParams(512, 40, 128, 512, 3, 4, 6), + _ConvTasNetParams(512, 40, 128, 512, 3, 8, 3), + _ConvTasNetParams(512, 32, 128, 512, 3, 8, 3), + _ConvTasNetParams(512, 16, 128, 512, 3, 8, 3), + ], + ))) + def test_paper_configuration(self, num_sources, model_params): + """ConvTasNet model works on the valid configurations in the paper""" + batch_size = 32 + num_frames = 8000 + + model = ConvTasNet( + num_sources=num_sources, + enc_kernel_size=model_params.enc_kernel_size, + enc_num_feats=model_params.enc_num_feats, + msk_kernel_size=model_params.msk_kernel_size, + msk_num_feats=model_params.msk_num_feats, + msk_num_hidden_feats=model_params.msk_num_hidden_feats, + msk_num_layers=model_params.msk_num_layers, + msk_num_stacks=model_params.msk_num_stacks, + ) + tensor = torch.rand(batch_size, 1, num_frames) + output = model(tensor) + + assert output.shape == (batch_size, num_sources, num_frames) diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py index 8e05b8b509..2abee7b437 100644 --- a/torchaudio/models/__init__.py +++ b/torchaudio/models/__init__.py @@ -1,2 +1,3 @@ from .wav2letter import * from .wavernn import * +from .conv_tasnet import ConvTasNet diff --git a/examples/source_separation/conv_tasnet/model.py b/torchaudio/models/conv_tasnet.py similarity index 100% rename from examples/source_separation/conv_tasnet/model.py rename to torchaudio/models/conv_tasnet.py