-
Notifications
You must be signed in to change notification settings - Fork 739
Add WaveRNN Model #735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add WaveRNN Model #735
Changes from all commits
9aacb78
2f91a9b
63391d6
0346f23
27e26aa
6981d1c
c41ac8f
c343166
17455a3
2c44bc7
634bc7f
6a2b8a7
b547482
978e101
01fbbda
0ed6da8
1cb02fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,10 @@ | ||
| from typing import List | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
| from torch import nn | ||
|
|
||
| __all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork"] | ||
| __all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"] | ||
|
|
||
|
|
||
| class _ResBlock(nn.Module): | ||
|
|
@@ -192,3 +193,139 @@ def forward(self, specgram: Tensor) -> Tensor: | |
| upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] | ||
|
|
||
| return upsampling_output, resnet_output | ||
|
|
||
|
|
||
| class _WaveRNN(nn.Module): | ||
| r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_. | ||
|
|
||
| The original implementation was introduced in | ||
| `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_. | ||
| The input channels of waveform and spectrogram have to be 1. The product of | ||
vincentqb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| `upsample_scales` must equal `hop_length`. | ||
|
|
||
| Args: | ||
| upsample_scales: the list of upsample scales | ||
| n_bits: the bits of output waveform | ||
| sample_rate: the sample rate of the audio waveform (samples per second) | ||
| hop_length: the number of samples between the starts of consecutive frames | ||
| n_res_block: the number of ResBlock in stack (default=10) | ||
| n_rnn: the dimension of RNN layer (default=512) | ||
| n_fc: the dimension of fully connected layer (default=512) | ||
| kernel_size: the number of kernel size in the first Conv1d layer (default=5) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| n_freq: the number of bins in a spectrogram (default=128) | ||
| n_hidden: the number of hidden dimensions (default=128) | ||
| n_output: the number of output dimensions (default=128) | ||
| mode: the output mode, one of ['waveform', 'mol'] (default='waveform') | ||
|
|
||
| Example | ||
| >>> wavernn = _WaveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200) | ||
| >>> waveform, sample_rate = torchaudio.load(file) | ||
| >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) | ||
| >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) | ||
| >>> output = wavernn(waveform, specgram) | ||
| >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits) | ||
| """ | ||
|
|
||
| def __init__(self, | ||
| upsample_scales: List[int], | ||
| n_bits: int, | ||
| sample_rate: int, | ||
| hop_length: int, | ||
| n_res_block: int = 10, | ||
| n_rnn: int = 512, | ||
| n_fc: int = 512, | ||
| kernel_size: int = 5, | ||
| n_freq: int = 128, | ||
| n_hidden: int = 128, | ||
| n_output: int = 128, | ||
| mode: str = 'waveform') -> None: | ||
| super().__init__() | ||
|
|
||
| self.mode = mode | ||
| self.kernel_size = kernel_size | ||
|
|
||
| if self.mode == 'waveform': | ||
| self.n_classes = 2 ** n_bits | ||
| elif self.mode == 'mol': | ||
| self.n_classes = 30 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you throw an error when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated in this PR.
Comment on lines
+247
to
+250
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's replace cc comment
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I changed the |
||
| else: | ||
| raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}") | ||
|
|
||
| self.n_rnn = n_rnn | ||
| self.n_aux = n_output // 4 | ||
| self.hop_length = hop_length | ||
| self.sample_rate = sample_rate | ||
|
|
||
| total_scale = 1 | ||
| for upsample_scale in upsample_scales: | ||
| total_scale *= upsample_scale | ||
| if total_scale != self.hop_length: | ||
| raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") | ||
|
|
||
| self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's change to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the name as |
||
| self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) | ||
|
|
||
| self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) | ||
| self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True) | ||
|
|
||
| self.relu1 = nn.ReLU(inplace=True) | ||
| self.relu2 = nn.ReLU(inplace=True) | ||
|
|
||
| self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc) | ||
| self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc) | ||
| self.fc3 = nn.Linear(n_fc, self.n_classes) | ||
|
|
||
| def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: | ||
| r"""Pass the input through the _WaveRNN model. | ||
|
|
||
| Args: | ||
| waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length) | ||
| specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time) | ||
|
|
||
| Return: | ||
| Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes), where n_classes is 2 ** n_bits in 'waveform' mode and 30 in 'mol' mode | ||
| """ | ||
|
|
||
| assert waveform.size(1) == 1, 'Require the input channel of waveform is 1' | ||
| assert specgram.size(1) == 1, 'Require the input channel of specgram is 1' | ||
| # remove channel dimension until the end | ||
| waveform, specgram = waveform.squeeze(1), specgram.squeeze(1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment has been added. |
||
|
|
||
| batch_size = waveform.size(0) | ||
| h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) | ||
| h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) | ||
| # output of upsample: | ||
| # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) | ||
| # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) | ||
| specgram, aux = self.upsample(specgram) | ||
| specgram = specgram.transpose(1, 2) | ||
| aux = aux.transpose(1, 2) | ||
|
|
||
| aux_idx = [self.n_aux * i for i in range(5)] | ||
| a1 = aux[:, :, aux_idx[0]:aux_idx[1]] | ||
| a2 = aux[:, :, aux_idx[1]:aux_idx[2]] | ||
| a3 = aux[:, :, aux_idx[2]:aux_idx[3]] | ||
| a4 = aux[:, :, aux_idx[3]:aux_idx[4]] | ||
|
|
||
| x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1) | ||
| x = self.fc(x) | ||
| res = x | ||
| x, _ = self.rnn1(x, h1) | ||
|
|
||
| x = x + res | ||
| res = x | ||
| x = torch.cat([x, a2], dim=-1) | ||
| x, _ = self.rnn2(x, h2) | ||
|
|
||
| x = x + res | ||
| x = torch.cat([x, a3], dim=-1) | ||
| x = self.fc1(x) | ||
| x = self.relu1(x) | ||
|
|
||
| x = torch.cat([x, a4], dim=-1) | ||
| x = self.fc2(x) | ||
| x = self.relu2(x) | ||
| x = self.fc3(x) | ||
|
|
||
| # bring back channel dimension | ||
| return x.unsqueeze(1) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, explanation of test is required. otherwise it will be difficult to make proper changes to this test later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test comments added.
I updated this change to each stack PR (#751, #724, #735 )