diff --git a/test/test_models.py b/test/test_models.py
index 519fbc7b26..c54a57cebd 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,5 +1,5 @@
 import torch
-from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork
+from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN
 
 from . import common_utils
 
@@ -81,3 +81,62 @@ def test_waveform(self):
 
         assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1))
         assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1))
+
+
+class TestWaveRNN(common_utils.TorchaudioTestCase):
+
+    def test_waveform(self):
+        """Validate the output dimensions of a _WaveRNN model in waveform mode.
+        """
+
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_bits = 9
+        sample_rate = 24000
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+        mode = 'waveform'
+
+        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
+                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
+
+        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
+        mels = torch.rand(n_batch, 1, n_freq, n_time)
+        out = model(x, mels)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)
+
+    def test_mol(self):
+        """Validate the output dimensions of a _WaveRNN model in mol mode.
+        """
+
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_bits = 9
+        sample_rate = 24000
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+        mode = 'mol'
+
+        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
+                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
+
+        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
+        mels = torch.rand(n_batch, 1, n_freq, n_time)
+        out = model(x, mels)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py
index 1df9eb0637..cd2e89a10c 100644
--- a/torchaudio/models/_wavernn.py
+++ b/torchaudio/models/_wavernn.py
@@ -1,9 +1,10 @@
 from typing import List
 
+import torch
 from torch import Tensor
 from torch import nn
 
-__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork"]
+__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]
 
 
 class _ResBlock(nn.Module):
@@ -192,3 +193,139 @@ def forward(self, specgram: Tensor) -> Tensor:
         upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]
 
         return upsampling_output, resnet_output
+
+
+class _WaveRNN(nn.Module):
+    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.
+
+    The original implementation was introduced in
+    `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
+    Both the waveform and the spectrogram inputs must have a single channel.
+    The product of `upsample_scales` must equal `hop_length`.
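+
+    For example, with `upsample_scales = [5, 5, 8]` (so `hop_length = 200`), a
+    spectrogram with `n_time` frames is paired with a waveform of
+    `(n_time - kernel_size + 1) * hop_length` samples.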
+
+    Args:
+        upsample_scales: the list of upsample scales
+        n_bits: the bit depth of the output waveform
+        sample_rate: the sample rate of the audio (samples per second)
+        hop_length: the number of samples between the starts of consecutive frames
+        n_res_block: the number of ResBlocks in the stack (default=10)
+        n_rnn: the dimension of the RNN layers (default=512)
+        n_fc: the dimension of the fully connected layers (default=512)
+        kernel_size: the kernel size of the first Conv1d layer (default=5)
+        n_freq: the number of bins in a spectrogram (default=128)
+        n_hidden: the number of hidden dimensions of the resnet block (default=128)
+        n_output: the number of output dimensions of the upsampling network (default=128)
+        mode: the output mode, one of ['waveform', 'mol'] (default='waveform')
+
+    Example
+        >>> wavernn = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=200)
+        >>> waveform, sample_rate = torchaudio.load(file)
+        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
+        >>> specgram = torchaudio.transforms.MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
+        >>> output = wavernn(waveform, specgram)
+        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+    """
+
+    def __init__(self,
+                 upsample_scales: List[int],
+                 n_bits: int,
+                 sample_rate: int,
+                 hop_length: int,
+                 n_res_block: int = 10,
+                 n_rnn: int = 512,
+                 n_fc: int = 512,
+                 kernel_size: int = 5,
+                 n_freq: int = 128,
+                 n_hidden: int = 128,
+                 n_output: int = 128,
+                 mode: str = 'waveform') -> None:
+        super().__init__()
+
+        self.mode = mode
+        self.kernel_size = kernel_size
+
+        if self.mode == 'waveform':
+            self.n_classes = 2 ** n_bits
+        elif self.mode == 'mol':
+            # 30 = 10 logistic mixtures x 3 parameters (weight, mean, log scale)
+            self.n_classes = 30
+        else:
+            raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
+
+        self.n_rnn = n_rnn
+        self.n_aux = n_output // 4
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+
+        total_scale = 1
+        for upsample_scale in upsample_scales:
+            total_scale *= upsample_scale
+        if total_scale != self.hop_length:
+            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
+
+        self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
+
+        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
+        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)
+
+        self.relu1 = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
+        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
+        self.fc3 = nn.Linear(n_fc, self.n_classes)
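+
+    # Note: the auxiliary features from the upsampling network (size n_output)
+    # are split in forward() into four chunks of size n_aux = n_output // 4,
+    # one concatenated into each stage: the fc/rnn1 input, the rnn2 input,
+    # and the fc1 and fc2 inputs.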
+
+    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
+        r"""Pass the input through the _WaveRNN model.
+
+        Args:
+            waveform: the input waveform of shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
+            specgram: the input spectrogram of shape (n_batch, 1, n_freq, n_time)
+
+        Return:
+            Tensor of shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes),
+            where n_classes is 2 ** n_bits in 'waveform' mode and 30 in 'mol' mode
+        """
+
+        assert waveform.size(1) == 1, 'Expected waveform to have 1 channel'
+        assert specgram.size(1) == 1, 'Expected specgram to have 1 channel'
+        # remove the channel dimension until the end of forward
+        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
+
+        batch_size = waveform.size(0)
+        # initial hidden states for the two GRU layers
+        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+        # output of upsample:
+        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+        specgram, aux = self.upsample(specgram)
+        specgram = specgram.transpose(1, 2)
+        aux = aux.transpose(1, 2)
+
+        aux_idx = [self.n_aux * i for i in range(5)]
+        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
+        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
+        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
+        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
+
+        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
+        x = self.fc(x)
+        res = x
+        x, _ = self.rnn1(x, h1)
+
+        # residual connection around rnn1
+        x = x + res
+        res = x
+        x = torch.cat([x, a2], dim=-1)
+        x, _ = self.rnn2(x, h2)
+
+        # residual connection around rnn2
+        x = x + res
+        x = torch.cat([x, a3], dim=-1)
+        x = self.fc1(x)
+        x = self.relu1(x)
+
+        x = torch.cat([x, a4], dim=-1)
+        x = self.fc2(x)
+        x = self.relu2(x)
+        x = self.fc3(x)
+
+        # bring back the channel dimension
+        return x.unsqueeze(1)
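
Reviewer note: below is a minimal smoke test for the new module, offered as a sketch rather than as part of the patch. It assumes the patch is applied so that `_WaveRNN` is importable from `torchaudio.models` (as in the updated test file above); the hyperparameters mirror `TestWaveRNN`.

import torch
from torchaudio.models import _WaveRNN

# The product of upsample_scales must equal hop_length: 5 * 5 * 8 == 200.
upsample_scales = [5, 5, 8]
hop_length = 200
kernel_size = 5        # _WaveRNN default, repeated here only for the shape arithmetic
n_bits = 9
n_batch, n_time, n_freq = 2, 200, 100

model = _WaveRNN(upsample_scales, n_bits, sample_rate=24000, hop_length=hop_length,
                 n_freq=n_freq, n_output=256, mode='waveform')

# Dummy inputs with the shapes documented in the class docstring.
waveform = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
specgram = torch.rand(n_batch, 1, n_freq, n_time)

out = model(waveform, specgram)
print(out.shape)
# Expected: torch.Size([2, 1, 39200, 512]), i.e.
# (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)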