Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 11 additions & 36 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ def test_waveform(self):
for upsample_scale in upsample_scales:
total_scale *= upsample_scale

model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
model = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)

x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x)
Expand All @@ -86,14 +91,13 @@ def test_waveform(self):
class TestWaveRNN(common_utils.TorchaudioTestCase):

def test_waveform(self):
"""Validate the output dimensions of a _WaveRNN model in waveform mode.
"""Validate the output dimensions of a _WaveRNN model.
"""

upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
n_classes = 512
hop_length = 200
n_batch = 2
n_time = 200
Expand All @@ -102,41 +106,12 @@ def test_waveform(self):
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'waveform'

model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)

x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)

assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)

def test_mol(self):
"""Validate the output dimensions of a _WaveRNN model in mol mode.
"""

upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_bits = 9
sample_rate = 24000
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
kernel_size = 5
mode = 'mol'

model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)

x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
out = model(x, mels)

assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
46 changes: 19 additions & 27 deletions torchaudio/models/_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class _MelResNet(nn.Module):
Args:
n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions of melresnet (default=128)
kernel_size: the kernel size of the first Conv1d layer (default=5)

Examples
Expand Down Expand Up @@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module):
upsample_scales: the list of upsample scales
n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions of melresnet (default=128)
kernel_size: the kernel size of the first Conv1d layer (default=5)

Examples
Expand Down Expand Up @@ -205,64 +205,56 @@ class _WaveRNN(nn.Module):

Args:
upsample_scales: the list of upsample scales
n_bits: the bits of output waveform
sample_rate: the rate of audio dimensions (samples per second)
n_classes: the number of output classes
hop_length: the number of samples between the starts of consecutive frames
n_res_block: the number of ResBlock in stack (default=10)
n_rnn: the dimension of the RNN layer (default=512)
n_fc: the dimension of the fully connected layer (default=512)
kernel_size: the kernel size of the first Conv1d layer (default=5)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
mode: the mode of waveform in ['waveform', 'mol'] (default='waveform')
n_hidden: the number of hidden dimensions of resblock (default=128)
n_output: the number of output dimensions of melresnet (default=128)

Example
>>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200)
>>> wavernn = _WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
>>> output = wavernn(waveform, specgram)
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
"""

def __init__(self,
upsample_scales: List[int],
n_bits: int,
sample_rate: int,
n_classes: int,
hop_length: int,
n_res_block: int = 10,
n_rnn: int = 512,
n_fc: int = 512,
kernel_size: int = 5,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
mode: str = 'waveform') -> None:
n_output: int = 128) -> None:
super().__init__()

self.mode = mode
self.kernel_size = kernel_size

if self.mode == 'waveform':
self.n_classes = 2 ** n_bits
elif self.mode == 'mol':
self.n_classes = 30
else:
raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")

self.n_rnn = n_rnn
self.n_aux = n_output // 4
self.hop_length = hop_length
self.sample_rate = sample_rate
self.n_classes = n_classes

total_scale = 1
for upsample_scale in upsample_scales:
total_scale *= upsample_scale
if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.upsample = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
Expand All @@ -283,7 +275,7 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)

Return:
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
"""

assert waveform.size(1) == 1, 'Require the input channel of waveform is 1'
Expand Down