diff --git a/test/test_models.py b/test/test_models.py index 57c86cc637..109112ee15 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -36,18 +36,20 @@ def test_mfcc(self): class TestMelResNet(common_utils.TorchaudioTestCase): def test_waveform(self): + """Validate the output dimensions of a _MelResNet block. + """ - batch_size = 2 - num_features = 200 - input_dims = 100 - output_dims = 128 - res_blocks = 10 - hidden_dims = 128 - pad = 2 + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 128 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 - model = _MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad) + model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) - x = torch.rand(batch_size, input_dims, num_features) + x = torch.rand(n_batch, n_freq, n_time) out = model(x) - assert out.size() == (batch_size, output_dims, num_features - pad * 2) + assert out.size() == (n_batch, n_output, n_time - kernel_size + 1) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 04155fb87c..4d9dc3144a 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -5,101 +5,83 @@ class _ResBlock(nn.Module): - r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning - for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016. - It is a block used in WaveRNN. WaveRNN is based on the paper "Efficient Neural Audio Synthesis". - Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart, - Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018. + r"""ResNet block based on "Deep Residual Learning for Image Recognition" + + The paper link is https://arxiv.org/pdf/1512.03385.pdf. Args: - num_dims: the number of compute dimensions in the input (default=128). + n_freq: the number of bins in a spectrogram (default=128) - Examples:: - >>> resblock = _ResBlock(num_dims=128) - >>> input = torch.rand(10, 128, 512) - >>> output = resblock(input) + Examples + >>> resblock = _ResBlock() + >>> input = torch.rand(10, 128, 512) # a random spectrogram + >>> output = resblock(input) # shape: (10, 128, 512) """ - def __init__(self, num_dims: int = 128) -> None: + def __init__(self, n_freq: int = 128) -> None: super().__init__() self.resblock_model = nn.Sequential( - nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False), - nn.BatchNorm1d(num_dims), + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False), - nn.BatchNorm1d(num_dims) + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq) ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the _ResBlock layer. - Args: - x: the input sequence to the _ResBlock layer (required). + specgram (Tensor): the input sequence to the _ResBlock layer (n_batch, n_freq, n_time). - Shape: - - x: :math:`(N, S, T)`. - - output: :math:`(N, S, T)`. - where N is the batch size, S is the number of input sequence, - T is the length of input sequence. + Return: + Tensor shape: (n_batch, n_freq, n_time) """ - residual = x - return self.resblock_model(x) + residual + return self.resblock_model(specgram) + specgram class _MelResNet(nn.Module): - r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN. - WaveRNN is based on the paper "Efficient Neural Audio Synthesis". Nal Kalchbrenner, Erich Elsen, - Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart, Florian Stimberg, Aaron van den Oord, - Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018. + r"""MelResNet layer uses a stack of ResBlocks on spectrogram. Args: - res_blocks: the number of ResBlock in stack (default=10). - input_dims: the number of input sequence (default=100). - hidden_dims: the number of compute dimensions (default=128). - output_dims: the number of output sequence (default=128). - pad: the number of kernal size (pad * 2 + 1) in the first Conv1d layer (default=2). - - Examples:: - >>> melresnet = _MelResNet(res_blocks=10, input_dims=100, - hidden_dims=128, output_dims=128, pad=2) - >>> input = torch.rand(10, 100, 512) - >>> output = melresnet(input) + n_res_block: the number of ResBlock in stack (default=10) + n_freq: the number of bins in a spectrogram (default=128) + n_hidden: the number of hidden dimensions (default=128) + n_output: the number of output dimensions (default=128) + kernel_size: the number of kernel size in the first Conv1d layer (default=5) + + Examples + >>> melresnet = _MelResNet() + >>> input = torch.rand(10, 128, 512) # a random spectrogram + >>> output = melresnet(input) # shape: (10, 128, 508) """ - def __init__(self, res_blocks: int = 10, - input_dims: int = 100, - hidden_dims: int = 128, - output_dims: int = 128, - pad: int = 2) -> None: + def __init__(self, + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: super().__init__() - kernel_size = pad * 2 + 1 - ResBlocks = [] - - for i in range(res_blocks): - ResBlocks.append(_ResBlock(hidden_dims)) + ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)] self.melresnet_model = nn.Sequential( - nn.Conv1d(in_channels=input_dims, out_channels=hidden_dims, kernel_size=kernel_size, bias=False), - nn.BatchNorm1d(hidden_dims), + nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), + nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True), *ResBlocks, - nn.Conv1d(in_channels=hidden_dims, out_channels=output_dims, kernel_size=1) + nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the _MelResNet layer. - Args: - x: the input sequence to the _MelResNet layer (required). + specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time). - Shape: - - x: :math:`(N, S, T)`. - - output: :math:`(N, P, T - 2 * pad)`. - where N is the batch size, S is the number of input sequence, - P is the number of output sequence, T is the length of input sequence. + Return: + Tensor shape: (n_batch, n_output, n_time - kernel_size + 1) """ - return self.melresnet_model(x) + return self.melresnet_model(specgram)