61 changes: 60 additions & 1 deletion test/test_models.py
@@ -1,5 +1,5 @@
import torch
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN

from . import common_utils

@@ -81,3 +81,62 @@ def test_waveform(self):

assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1))
assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1))


class TestWaveRNN(common_utils.TorchaudioTestCase):

def test_waveform(self):
"""Validate the output dimensions of a _WaveRNN model in waveform mode.
"""

upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'waveform'

model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)

x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)

assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)
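# i.e. 2 ** 9 = 512 output classes for each of 200 * (200 - 5 + 1) = 39200 samples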

def test_mol(self):
"""Validate the output dimensions of a _WaveRNN model in mol mode.
"""

Contributor: same here: an explanation of the test is required; otherwise it will be difficult to make proper changes to this test later.

Contributor Author: Test comments added. I updated this change in each stacked PR (#751, #724, #735).

upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'mol'

model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)

x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)

assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
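As an aside, here is a minimal sketch (not part of this PR) of how the waveform-mode logits checked above could be decoded back into audio samples. The name sample_waveform_mode and the class-to-amplitude mapping are assumptions for illustration; a real recipe might use mu-law companding instead.

    import torch

    def sample_waveform_mode(logits: torch.Tensor, n_bits: int = 9) -> torch.Tensor:
        # logits: (n_batch, 1, n_samples, 2 ** n_bits), as asserted in test_waveform
        classes = torch.distributions.Categorical(logits=logits).sample()  # (n_batch, 1, n_samples)
        # map class indices from [0, 2 ** n_bits - 1] back to [-1.0, 1.0]
        return 2.0 * classes.float() / (2 ** n_bits - 1.0) - 1.0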
139 changes: 138 additions & 1 deletion torchaudio/models/_wavernn.py
@@ -1,9 +1,10 @@
from typing import List

import torch
from torch import Tensor
from torch import nn

__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork"]
__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]


class _ResBlock(nn.Module):
@@ -192,3 +193,139 @@ def forward(self, specgram: Tensor) -> Tensor:
upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]

return upsampling_output, resnet_output


class _WaveRNN(nn.Module):
r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.

The original implementation was introduced in
`"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
The input channels of the waveform and spectrogram must both be 1. The product of
`upsample_scales` must equal `hop_length` (e.g. 5 * 5 * 8 = 200).

Args:
upsample_scales: the list of upsample scales
n_bits: the number of bits in the output waveform
sample_rate: the sample rate of the audio (samples per second)
hop_length: the number of samples between the starts of consecutive frames
n_res_block: the number of ResBlocks in the stack (default=10)
n_rnn: the dimension of the RNN layers (default=512)
n_fc: the dimension of the fully connected layers (default=512)
kernel_size: the kernel size of the first Conv1d layer (default=5)
Contributor: nit: see here:

(Default: 5)

(similar below)

Contributor Author: Since the other blocks (#751, #724) use the same format (with =), I plan to open a separate pull request to change all of them later. For now I will keep them the same here.

n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
mode: the output mode, one of ['waveform', 'mol'] (default='waveform')

Example
>>> wavernn = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=200)
>>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
>>> output = wavernn(waveform, specgram)
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
"""

def __init__(self,
upsample_scales: List[int],
n_bits: int,
sample_rate: int,
hop_length: int,
n_res_block: int = 10,
n_rnn: int = 512,
n_fc: int = 512,
kernel_size: int = 5,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
mode: str = 'waveform') -> None:
super().__init__()

self.mode = mode
self.kernel_size = kernel_size

if self.mode == 'waveform':
self.n_classes = 2 ** n_bits
elif self.mode == 'mol':
self.n_classes = 30
Contributor: Can you throw an error when the mode is invalid?

Contributor Author (@jimchen90, Jun 25, 2020): Updated in this PR.

Comment on lines +247 to +250

Contributor: Let's replace the mode and n_bits parameters simply with n_classes. (cc comment)

Contributor Author: Yes, I changed mode to loss and replaced n_bits with n_classes in #797.
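For context, a minimal sketch of what that simplification might look like; the class name _WaveRNNSimplified is hypothetical, and the actual change in #797 may differ:

    from typing import List
    from torch import nn

    class _WaveRNNSimplified(nn.Module):
        # hypothetical revision: callers pass n_classes directly
        # (2 ** n_bits for a quantized waveform, 30 for mixture of logistics),
        # so the (mode, n_bits) pair disappears from the signature
        def __init__(self, upsample_scales: List[int], n_classes: int, hop_length: int) -> None:
            super().__init__()
            self.n_classes = n_classes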

else:
raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")

self.n_rnn = n_rnn
self.n_aux = n_output // 4
self.hop_length = hop_length
self.sample_rate = sample_rate

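# the product of the upsample scales must equal hop_length (e.g. 5 * 5 * 8 = 200)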
total_scale = 1
for upsample_scale in upsample_scales:
total_scale *= upsample_scale
if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
Contributor: Let's change these to n_hidden_resblock and n_output_upsample. Thoughts?

Contributor Author (@jimchen90, Jul 17, 2020): I updated the names to n_hidden_resblock and n_output_melresnet in #797. Because n_output is used in the MelResNet block, and the MelResNet block is one part of the upsample block, I used n_output_melresnet. Any suggestions?

self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)

self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
self.fc3 = nn.Linear(n_fc, self.n_classes)

def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
r"""Pass the input through the _WaveRNN model.

Args:
waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)

Return:
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
"""

assert waveform.size(1) == 1, 'Require the input channel of waveform to be 1'
assert specgram.size(1) == 1, 'Require the input channel of specgram to be 1'
# remove channel dimension until the end
waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
Contributor: nit:

        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

Contributor Author: The comment has been added.


batch_size = waveform.size(0)
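# zero-initialized hidden states for the two GRU layers below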
h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
# output of upsample:
# specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
# aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
specgram, aux = self.upsample(specgram)
specgram = specgram.transpose(1, 2)
aux = aux.transpose(1, 2)

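# the conditioner output is split into four equal chunks, one fed at each stage below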
aux_idx = [self.n_aux * i for i in range(5)]
a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
x = self.fc(x)
res = x
x, _ = self.rnn1(x, h1)

x = x + res
res = x
x = torch.cat([x, a2], dim=-1)
x, _ = self.rnn2(x, h2)

x = x + res
x = torch.cat([x, a3], dim=-1)
x = self.fc1(x)
x = self.relu1(x)

x = torch.cat([x, a4], dim=-1)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x)

# bring back channel dimension
return x.unsqueeze(1)
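For readers unfamiliar with the mol head: the 30 output channels conventionally encode 10 logistic mixture components, each with a weight, a mean, and a log-scale. A short sketch of that split, assuming the channel layout used in fatchord's reference implementation:

    import torch

    # out: (n_batch, 1, n_samples, 30) from a _WaveRNN in 'mol' mode
    out = torch.randn(2, 1, 39200, 30)
    # assumed layout: [mixture logits | means | log-scales], 10 channels each
    logit_probs, means, log_scales = out.chunk(3, dim=-1)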