|
from typing import List

+import torch
from torch import Tensor
from torch import nn

-__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork"]
+__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]


class _ResBlock(nn.Module):
@@ -192,3 +193,139 @@ def forward(self, specgram: Tensor) -> Tensor:
        upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]

        return upsampling_output, resnet_output
+
+
+class _WaveRNN(nn.Module):
+    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.
+
+    The original implementation was introduced in
+    `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
+    The input channels of waveform and spectrogram have to be 1. The product of
+    `upsample_scales` must equal `hop_length`.
+
+    Args:
+        upsample_scales: the list of upsample scales
+        n_bits: the bit depth of the output waveform
+        sample_rate: the sample rate of the audio (samples per second)
+        hop_length: the number of samples between the starts of consecutive frames
+        n_res_block: the number of ResBlocks in the stack (default=10)
+        n_rnn: the dimension of the RNN layers (default=512)
+        n_fc: the dimension of the fully connected layers (default=512)
+        kernel_size: the kernel size of the first Conv1d layer (default=5)
+        n_freq: the number of bins in a spectrogram (default=128)
+        n_hidden: the number of hidden dimensions (default=128)
+        n_output: the number of output dimensions (default=128)
+        mode: the mode of the output waveform, one of ['waveform', 'mol'] (default='waveform')
+
+    Example
+        >>> wavernn = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=200)
+        >>> waveform, sample_rate = torchaudio.load(file)
+        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
+        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
+        >>> output = wavernn(waveform, specgram)
+        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+    """
+
+    def __init__(self,
+                 upsample_scales: List[int],
+                 n_bits: int,
+                 sample_rate: int,
+                 hop_length: int,
+                 n_res_block: int = 10,
+                 n_rnn: int = 512,
+                 n_fc: int = 512,
+                 kernel_size: int = 5,
+                 n_freq: int = 128,
+                 n_hidden: int = 128,
+                 n_output: int = 128,
+                 mode: str = 'waveform') -> None:
+        super().__init__()
+
+        self.mode = mode
+        self.kernel_size = kernel_size
+
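+        # number of output classes: discrete amplitude levels in 'waveform'
+        # mode; in 'mol' mode, 30 is the parameter count of the
+        # mixture-of-logistics output used by the reference implementation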
+        if self.mode == 'waveform':
+            self.n_classes = 2 ** n_bits
+        elif self.mode == 'mol':
+            self.n_classes = 30
+        else:
+            raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
+
+        self.n_rnn = n_rnn
+        self.n_aux = n_output // 4
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+
+        total_scale = 1
+        for upsample_scale in upsample_scales:
+            total_scale *= upsample_scale
+        if total_scale != self.hop_length:
+            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
+
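+        # the upsampling network stretches the spectrogram to waveform
+        # resolution and yields auxiliary features for conditioning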
+        self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
+
+        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
+        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)
+
+        self.relu1 = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
+        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
+        self.fc3 = nn.Linear(n_fc, self.n_classes)
+
+    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
+        r"""Pass the input through the _WaveRNN model.
+
+        Args:
+            waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
+            specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
+
+        Return:
+            Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+        """
+
+        assert waveform.size(1) == 1, 'Require the input channel of waveform to be 1'
+        assert specgram.size(1) == 1, 'Require the input channel of specgram to be 1'
+        # remove the channel dimension until the end
+        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
+
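+        # zero initial hidden states for the two GRUs, matching the
+        # waveform's dtype and device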
+        batch_size = waveform.size(0)
+        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+        # output of upsample:
+        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+        specgram, aux = self.upsample(specgram)
+        specgram = specgram.transpose(1, 2)
+        aux = aux.transpose(1, 2)
+
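+        # split the auxiliary features into four equal chunks along the last
+        # dimension; each chunk conditions a different stage of the network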
+        aux_idx = [self.n_aux * i for i in range(5)]
+        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
+        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
+        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
+        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
+
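+        # concatenate the input sample, the spectrogram frame and the first
+        # auxiliary chunk at every timestep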
+        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
+        x = self.fc(x)
+        res = x
+        x, _ = self.rnn1(x, h1)
+
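+        # residual connection around the first GRU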
+        x = x + res
+        res = x
+        x = torch.cat([x, a2], dim=-1)
+        x, _ = self.rnn2(x, h2)
+
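+        # residual connection around the second GRU, followed by the
+        # fully connected output head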
+        x = x + res
+        x = torch.cat([x, a3], dim=-1)
+        x = self.fc1(x)
+        x = self.relu1(x)
+
+        x = torch.cat([x, a4], dim=-1)
+        x = self.fc2(x)
+        x = self.relu2(x)
+        x = self.fc3(x)
+
+        # bring back the channel dimension
+        return x.unsqueeze(1)
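
A quick way to sanity-check the shape contract documented in the docstring is a smoke test on random tensors. This is a sketch, not part of the diff: the import path `wavernn` is hypothetical, and it relies on the documented defaults (`n_freq=128`) and on the `_MelResNet`/`_UpsampleNetwork` classes defined earlier in the file.

```python
import torch

from wavernn import _WaveRNN  # hypothetical import path for the file in this diff

n_batch, n_time, kernel_size, hop_length = 2, 16, 5, 200

# upsample_scales must multiply to hop_length: 5 * 5 * 8 = 200
model = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=hop_length)

specgram = torch.rand(n_batch, 1, 128, n_time)  # n_freq defaults to 128
waveform = torch.rand(n_batch, 1, (n_time - kernel_size + 1) * hop_length)

output = model(waveform, specgram)
# (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
print(output.shape)  # torch.Size([2, 1, 2400, 512])
```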